diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 1331badae..5b8f76a1e 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -11,7 +11,10 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + mlir::triton::gpu::TritonGPUDialect, + mlir::arith::ArithmeticDialect, + mlir::StandardOpsDialect, + mlir::scf::SCFDialect>(); return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "Triton (GPU) optimizer driver\n", registry) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index da5b09de5..a23c07c57 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -102,12 +102,12 @@ TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) { } void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const { - printer << "<" + printer << "<{" << "threadTileSize = [" << getThreadTileSize() << "]" << ", warpTileSize = [" << getWarpTileSize() << "]" << ", blockTileSize = [" << getBlockTileSize() << "]" << ", order = [" << getOrder() << "]" - << ">"; + << "}>"; } Attribute @@ -175,11 +175,11 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { } void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { - printer << "<" + printer << "<{" << "vec = " << getVec() << ", perPhase = " << getPerPhase() << ", order = [" << getOrder() << "]" - << ">"; + << "}>"; } void TritonGPUDialect::initialize() { diff --git a/test/TritonGPU/layout.mlir b/test/TritonGPU/layout.mlir new file mode 100644 index 000000000..a5b3dc766 --- /dev/null +++ b/test/TritonGPU/layout.mlir @@ -0,0 +1,26 @@ +// RUN: triton-opt %s -split-input-file -verify-diagnostics + +#reg = #triton_gpu.sharded_layout<{ + threadTileSize = [1, 1], + warpTileSize = [32, 1], + blockTileSize = [64, 1], + order = [0, 1] +}> + +#reg2 = #triton_gpu.sharded_layout<{ + threadTileSize = [2, 1], + warpTileSize = [64, 1], + blockTileSize = [128, 1], + order = [0, 1] +}> + +func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { + %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg> + return +} + +func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}} + // expected-error @+1 {{use of value '%arg0' expects different type than prior uses}} + %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2> + return +} \ No newline at end of file