Better textual representation
This commit is contained in:
@@ -156,7 +156,7 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
|
|||||||
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
||||||
let summary = "reshape";
|
let summary = "reshape";
|
||||||
|
|
||||||
let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape);
|
let arguments = (ins TT_Tensor:$src);
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
}
|
}
|
||||||
|
@@ -1192,7 +1192,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
||||||
return self.create<mlir::triton::ReshapeOp>(
|
return self.create<mlir::triton::ReshapeOp>(
|
||||||
loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape)
|
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
.def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
.def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
||||||
|
@@ -38,10 +38,10 @@ module {
|
|||||||
%23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32>
|
%23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32>
|
||||||
%24 = arith.addi %23, %22 : tensor<64xi32>
|
%24 = arith.addi %23, %22 : tensor<64xi32>
|
||||||
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||||
%26 = "triton.reshape"(%20) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
%26 = "triton.reshape"(%20) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32>
|
%27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32>
|
||||||
%28 = arith.muli %26, %27 : tensor<64x1xi32>
|
%28 = arith.muli %26, %27 : tensor<64x1xi32>
|
||||||
%29 = "triton.reshape"(%25) {shape = [1, 32]} : (tensor<32xi32>) -> tensor<1x32xi32>
|
%29 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||||
%c1_i32_9 = arith.constant 1 : i32
|
%c1_i32_9 = arith.constant 1 : i32
|
||||||
%30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32>
|
%30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32>
|
||||||
%31 = arith.muli %29, %30 : tensor<1x32xi32>
|
%31 = arith.muli %29, %30 : tensor<1x32xi32>
|
||||||
@@ -50,10 +50,10 @@ module {
|
|||||||
%34 = arith.addi %32, %33 : tensor<64x32xi32>
|
%34 = arith.addi %32, %33 : tensor<64x32xi32>
|
||||||
%35 = "triton.broadcast"(%arg0) : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
|
%35 = "triton.broadcast"(%arg0) : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
|
||||||
%36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr<f16>>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr<f16>>
|
%36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr<f16>>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr<f16>>
|
||||||
%37 = "triton.reshape"(%25) {shape = [32, 1]} : (tensor<32xi32>) -> tensor<32x1xi32>
|
%37 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||||
%38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32>
|
%38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32>
|
||||||
%39 = arith.muli %37, %38 : tensor<32x1xi32>
|
%39 = arith.muli %37, %38 : tensor<32x1xi32>
|
||||||
%40 = "triton.reshape"(%24) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
%40 = "triton.reshape"(%24) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%c1_i32_10 = arith.constant 1 : i32
|
%c1_i32_10 = arith.constant 1 : i32
|
||||||
%41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32>
|
%41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32>
|
||||||
%42 = arith.muli %40, %41 : tensor<1x64xi32>
|
%42 = arith.muli %40, %41 : tensor<1x64xi32>
|
||||||
@@ -100,22 +100,22 @@ module {
|
|||||||
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
%60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32>
|
%60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32>
|
||||||
%61 = arith.addi %60, %59 : tensor<64xi32>
|
%61 = arith.addi %60, %59 : tensor<64xi32>
|
||||||
%62 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
%62 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32>
|
%63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32>
|
||||||
%64 = arith.muli %63, %62 : tensor<64x1xi32>
|
%64 = arith.muli %63, %62 : tensor<64x1xi32>
|
||||||
%65 = "triton.broadcast"(%arg2) : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
|
%65 = "triton.broadcast"(%arg2) : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
|
||||||
%66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr<f16>>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr<f16>>
|
%66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr<f16>>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr<f16>>
|
||||||
%67 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
%67 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%c1_i32_13 = arith.constant 1 : i32
|
%c1_i32_13 = arith.constant 1 : i32
|
||||||
%68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32>
|
%68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32>
|
||||||
%69 = arith.muli %67, %68 : tensor<1x64xi32>
|
%69 = arith.muli %67, %68 : tensor<1x64xi32>
|
||||||
%70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
|
%70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
|
||||||
%71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
%71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||||
%72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr<f16>>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr<f16>>
|
%72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr<f16>>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr<f16>>
|
||||||
%73 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
%73 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32>
|
%74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32>
|
||||||
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
|
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
|
||||||
%76 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
%76 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32>
|
%77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32>
|
||||||
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
|
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
|
||||||
%79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
%79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||||
|
Reference in New Issue
Block a user