more TritonGPU unit tests
This commit is contained in:
@@ -11,7 +11,10 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
registry.insert<mlir::triton::TritonDialect,
|
registry.insert<mlir::triton::TritonDialect,
|
||||||
mlir::triton::gpu::TritonGPUDialect>();
|
mlir::triton::gpu::TritonGPUDialect,
|
||||||
|
mlir::arith::ArithmeticDialect,
|
||||||
|
mlir::StandardOpsDialect,
|
||||||
|
mlir::scf::SCFDialect>();
|
||||||
|
|
||||||
return mlir::asMainReturnCode(
|
return mlir::asMainReturnCode(
|
||||||
mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry)
|
mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry)
|
||||||
|
@@ -102,12 +102,12 @@ TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
printer << "<"
|
printer << "<{"
|
||||||
<< "threadTileSize = [" << getThreadTileSize() << "]"
|
<< "threadTileSize = [" << getThreadTileSize() << "]"
|
||||||
<< ", warpTileSize = [" << getWarpTileSize() << "]"
|
<< ", warpTileSize = [" << getWarpTileSize() << "]"
|
||||||
<< ", blockTileSize = [" << getBlockTileSize() << "]"
|
<< ", blockTileSize = [" << getBlockTileSize() << "]"
|
||||||
<< ", order = [" << getOrder() << "]"
|
<< ", order = [" << getOrder() << "]"
|
||||||
<< ">";
|
<< "}>";
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute
|
Attribute
|
||||||
@@ -175,11 +175,11 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||||
printer << "<"
|
printer << "<{"
|
||||||
<< "vec = " << getVec()
|
<< "vec = " << getVec()
|
||||||
<< ", perPhase = " << getPerPhase()
|
<< ", perPhase = " << getPerPhase()
|
||||||
<< ", order = [" << getOrder() << "]"
|
<< ", order = [" << getOrder() << "]"
|
||||||
<< ">";
|
<< "}>";
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUDialect::initialize() {
|
void TritonGPUDialect::initialize() {
|
||||||
|
26
test/TritonGPU/layout.mlir
Normal file
26
test/TritonGPU/layout.mlir
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// RUN: triton-opt %s -split-input-file -verify-diagnostics
|
||||||
|
|
||||||
|
#reg = #triton_gpu.sharded_layout<{
|
||||||
|
threadTileSize = [1, 1],
|
||||||
|
warpTileSize = [32, 1],
|
||||||
|
blockTileSize = [64, 1],
|
||||||
|
order = [0, 1]
|
||||||
|
}>
|
||||||
|
|
||||||
|
#reg2 = #triton_gpu.sharded_layout<{
|
||||||
|
threadTileSize = [2, 1],
|
||||||
|
warpTileSize = [64, 1],
|
||||||
|
blockTileSize = [128, 1],
|
||||||
|
order = [0, 1]
|
||||||
|
}>
|
||||||
|
|
||||||
|
func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) {
|
||||||
|
%0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}}
|
||||||
|
// expected-error @+1 {{use of value '%arg0' expects different type than prior uses}}
|
||||||
|
%0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2>
|
||||||
|
return
|
||||||
|
}
|
Reference in New Issue
Block a user