Better textual representation

This commit is contained in:
Yan Da
2022-04-07 20:44:41 +08:00
parent 0864b253bb
commit 6002340456
3 changed files with 10 additions and 10 deletions

View File

@@ -156,7 +156,7 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
let summary = "reshape";
let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape);
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
}

View File

@@ -1192,7 +1192,7 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
return self.create<mlir::triton::ReshapeOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape)
loc, mlir::RankedTensorType::get(shape, argType), arg
);
})
.def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {

View File

@@ -38,10 +38,10 @@ module {
%23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32>
%24 = arith.addi %23, %22 : tensor<64xi32>
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%26 = "triton.reshape"(%20) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
%26 = "triton.reshape"(%20) : (tensor<64xi32>) -> tensor<64x1xi32>
%27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32>
%28 = arith.muli %26, %27 : tensor<64x1xi32>
%29 = "triton.reshape"(%25) {shape = [1, 32]} : (tensor<32xi32>) -> tensor<1x32xi32>
%29 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32_9 = arith.constant 1 : i32
%30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32>
%31 = arith.muli %29, %30 : tensor<1x32xi32>
@@ -50,10 +50,10 @@ module {
%34 = arith.addi %32, %33 : tensor<64x32xi32>
%35 = "triton.broadcast"(%arg0) : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
%36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr<f16>>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr<f16>>
%37 = "triton.reshape"(%25) {shape = [32, 1]} : (tensor<32xi32>) -> tensor<32x1xi32>
%37 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<32x1xi32>
%38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32>
%39 = arith.muli %37, %38 : tensor<32x1xi32>
%40 = "triton.reshape"(%24) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
%40 = "triton.reshape"(%24) : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_10 = arith.constant 1 : i32
%41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32>
%42 = arith.muli %40, %41 : tensor<1x64xi32>
@@ -100,22 +100,22 @@ module {
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32>
%61 = arith.addi %60, %59 : tensor<64xi32>
%62 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
%62 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
%63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32>
%64 = arith.muli %63, %62 : tensor<64x1xi32>
%65 = "triton.broadcast"(%arg2) : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
%66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr<f16>>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr<f16>>
%67 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
%67 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_13 = arith.constant 1 : i32
%68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32>
%69 = arith.muli %67, %68 : tensor<1x64xi32>
%70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
%71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32>
%72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr<f16>>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr<f16>>
%73 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
%73 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
%74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32>
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
%76 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
%76 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
%77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32>
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
%79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1>