From 60023404560538e793566c8ff40ab55ef4f26a7b Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 7 Apr 2022 20:44:41 +0800 Subject: [PATCH] Better textual representation --- include/triton/ir/TritonOps.td | 2 +- python/src/triton.cc | 2 +- rewrite-test/jit/matmul/matmul.mlir | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 0b4516ff5..36fbdd2d1 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -156,7 +156,7 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape] def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> { let summary = "reshape"; - let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape); + let arguments = (ins TT_Tensor:$src); let results = (outs TT_Tensor:$result); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 57c0e875e..0a1ecff73 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1192,7 +1192,7 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); auto argType = arg.getType().dyn_cast().getElementType(); return self.create( - loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape) + loc, mlir::RankedTensorType::get(shape, argType), arg ); }) .def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir index fe392ed57..81b8a1bcf 100644 --- a/rewrite-test/jit/matmul/matmul.mlir +++ b/rewrite-test/jit/matmul/matmul.mlir @@ -38,10 +38,10 @@ module { %23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32> %24 = arith.addi %23, %22 : tensor<64xi32> %25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %26 = "triton.reshape"(%20) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32> + %26 = "triton.reshape"(%20) : (tensor<64xi32>) -> tensor<64x1xi32> %27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32> %28 = arith.muli %26, %27 : tensor<64x1xi32> - %29 = "triton.reshape"(%25) {shape = [1, 32]} : (tensor<32xi32>) -> tensor<1x32xi32> + %29 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<1x32xi32> %c1_i32_9 = arith.constant 1 : i32 %30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32> %31 = arith.muli %29, %30 : tensor<1x32xi32> @@ -50,10 +50,10 @@ module { %34 = arith.addi %32, %33 : tensor<64x32xi32> %35 = "triton.broadcast"(%arg0) : (!triton.ptr) -> tensor<64x32x!triton.ptr> %36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr> - %37 = "triton.reshape"(%25) {shape = [32, 1]} : (tensor<32xi32>) -> tensor<32x1xi32> + %37 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<32x1xi32> %38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32> %39 = arith.muli %37, %38 : tensor<32x1xi32> - %40 = "triton.reshape"(%24) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32> + %40 = "triton.reshape"(%24) : (tensor<64xi32>) -> tensor<1x64xi32> %c1_i32_10 = arith.constant 1 : i32 %41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32> %42 = arith.muli %40, %41 : tensor<1x64xi32> @@ -100,22 +100,22 @@ module { %59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> %60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32> %61 = arith.addi %60, %59 : tensor<64xi32> - %62 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32> + %62 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32> %63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32> %64 = arith.muli %63, %62 : tensor<64x1xi32> %65 = "triton.broadcast"(%arg2) : (!triton.ptr) -> tensor<64x1x!triton.ptr> %66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr> - %67 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32> + %67 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32> %c1_i32_13 = arith.constant 1 : i32 %68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32> %69 = arith.muli %67, %68 : tensor<1x64xi32> %70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr>) -> tensor<64x64x!triton.ptr> %71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32> %72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr> - %73 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32> + %73 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32> %74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32> %75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32> - %76 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32> + %76 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32> %77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32> %78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32> %79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1>