Better textual representation
This commit is contained in:
@@ -156,7 +156,7 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
|
||||
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
||||
let summary = "reshape";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape);
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
@@ -1192,7 +1192,7 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
||||
return self.create<mlir::triton::ReshapeOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape)
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
})
|
||||
.def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
||||
|
@@ -38,10 +38,10 @@ module {
|
||||
%23 = "triton.broadcast"(%21) : (i32) -> tensor<64xi32>
|
||||
%24 = arith.addi %23, %22 : tensor<64xi32>
|
||||
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||
%26 = "triton.reshape"(%20) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%26 = "triton.reshape"(%20) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%27 = "triton.broadcast"(%arg6) : (i32) -> tensor<64x1xi32>
|
||||
%28 = arith.muli %26, %27 : tensor<64x1xi32>
|
||||
%29 = "triton.reshape"(%25) {shape = [1, 32]} : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||
%29 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||
%c1_i32_9 = arith.constant 1 : i32
|
||||
%30 = "triton.broadcast"(%c1_i32_9) : (i32) -> tensor<1x32xi32>
|
||||
%31 = arith.muli %29, %30 : tensor<1x32xi32>
|
||||
@@ -50,10 +50,10 @@ module {
|
||||
%34 = arith.addi %32, %33 : tensor<64x32xi32>
|
||||
%35 = "triton.broadcast"(%arg0) : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
|
||||
%36 = "triton.getelementptr"(%35, %34) : (tensor<64x32x!triton.ptr<f16>>, tensor<64x32xi32>) -> tensor<64x32x!triton.ptr<f16>>
|
||||
%37 = "triton.reshape"(%25) {shape = [32, 1]} : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||
%37 = "triton.reshape"(%25) : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||
%38 = "triton.broadcast"(%arg7) : (i32) -> tensor<32x1xi32>
|
||||
%39 = arith.muli %37, %38 : tensor<32x1xi32>
|
||||
%40 = "triton.reshape"(%24) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%40 = "triton.reshape"(%24) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_10 = arith.constant 1 : i32
|
||||
%41 = "triton.broadcast"(%c1_i32_10) : (i32) -> tensor<1x64xi32>
|
||||
%42 = arith.muli %40, %41 : tensor<1x64xi32>
|
||||
@@ -100,22 +100,22 @@ module {
|
||||
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%60 = "triton.broadcast"(%58) : (i32) -> tensor<64xi32>
|
||||
%61 = arith.addi %60, %59 : tensor<64xi32>
|
||||
%62 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%62 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%63 = "triton.broadcast"(%arg8) : (i32) -> tensor<64x1xi32>
|
||||
%64 = arith.muli %63, %62 : tensor<64x1xi32>
|
||||
%65 = "triton.broadcast"(%arg2) : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
|
||||
%66 = "triton.getelementptr"(%65, %64) : (tensor<64x1x!triton.ptr<f16>>, tensor<64x1xi32>) -> tensor<64x1x!triton.ptr<f16>>
|
||||
%67 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%67 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_13 = arith.constant 1 : i32
|
||||
%68 = "triton.broadcast"(%c1_i32_13) : (i32) -> tensor<1x64xi32>
|
||||
%69 = arith.muli %67, %68 : tensor<1x64xi32>
|
||||
%70 = "triton.broadcast"(%66) : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
|
||||
%71 = "triton.broadcast"(%69) : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%72 = "triton.getelementptr"(%70, %71) : (tensor<64x64x!triton.ptr<f16>>, tensor<64x64xi32>) -> tensor<64x64x!triton.ptr<f16>>
|
||||
%73 = "triton.reshape"(%57) {shape = [64, 1]} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%73 = "triton.reshape"(%57) : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%74 = "triton.broadcast"(%arg3) : (i32) -> tensor<64x1xi32>
|
||||
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
|
||||
%76 = "triton.reshape"(%61) {shape = [1, 64]} : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%76 = "triton.reshape"(%61) : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%77 = "triton.broadcast"(%arg4) : (i32) -> tensor<1x64xi32>
|
||||
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
|
||||
%79 = "triton.broadcast"(%75) : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||
|
Reference in New Issue
Block a user