.
This commit is contained in:
@@ -41,13 +41,6 @@ public:
|
|||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
// Move convert(load) immediately after dependent load
|
|
||||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
|
||||||
auto load = dyn_cast<triton::LoadOp>(op.getOperand().getDefiningOp());
|
|
||||||
if(!load)
|
|
||||||
return;
|
|
||||||
op->moveAfter(load);
|
|
||||||
});
|
|
||||||
// Sink conversions into loops when they will increase
|
// Sink conversions into loops when they will increase
|
||||||
// register pressure
|
// register pressure
|
||||||
DenseMap<Operation*, Operation *> opToMove;
|
DenseMap<Operation*, Operation *> opToMove;
|
||||||
@@ -62,7 +55,18 @@ public:
|
|||||||
});
|
});
|
||||||
for(auto &kv: opToMove)
|
for(auto &kv: opToMove)
|
||||||
kv.first->moveBefore(kv.second);
|
kv.first->moveBefore(kv.second);
|
||||||
|
// Move convert(load) immediately after dependent load
|
||||||
|
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||||
|
auto dstType = op.getResult().getType().cast<RankedTensorType>();
|
||||||
|
auto dstEncoding = dstType.getEncoding();
|
||||||
|
if(!dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||||
|
return;
|
||||||
|
Operation* argOp = op.getOperand().getDefiningOp();
|
||||||
|
if(!argOp)
|
||||||
|
return;
|
||||||
|
llvm::outs() << "moving " << *op << "\n";
|
||||||
|
op->moveAfter(argOp);
|
||||||
|
});
|
||||||
// Move transpositions just after their definition
|
// Move transpositions just after their definition
|
||||||
opToMove.clear();
|
opToMove.clear();
|
||||||
m.walk([&](triton::TransOp op){
|
m.walk([&](triton::TransOp op){
|
||||||
|
@@ -7,14 +7,13 @@
|
|||||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
|
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
||||||
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
|
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
|
||||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
|
|
||||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
|
|
||||||
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
|
|
||||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
|
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%c128_i32 = arith.constant 128 : i32
|
%c128_i32 = arith.constant 128 : i32
|
||||||
%c128 = arith.constant 128 : index
|
%c128 = arith.constant 128 : index
|
||||||
|
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
|
||||||
|
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
|
||||||
|
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
|
||||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||||
%1 = arith.divsi %0, %arg22 : i32
|
%1 = arith.divsi %0, %arg22 : i32
|
||||||
%2 = arith.remsi %0, %arg22 : i32
|
%2 = arith.remsi %0, %arg22 : i32
|
||||||
@@ -77,92 +76,92 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
|||||||
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||||
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
||||||
%61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
|
%61 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
|
||||||
%62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
|
%62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
|
||||||
%63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
|
%63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
|
||||||
%64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1>
|
||||||
%65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
%65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
%66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||||
%67 = arith.index_cast %47 : i32 to index
|
%67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
||||||
%68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
|
%68 = tt.trans %67 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
|
||||||
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
|
%69 = arith.index_cast %47 : i32 to index
|
||||||
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
|
%70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
|
||||||
%71 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
|
%71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
|
||||||
%72 = tt.broadcast %71 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
|
%72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
|
||||||
%73 = arith.addi %72, %26 : tensor<128x64xi32, #blocked2>
|
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
|
||||||
%74 = tt.addptr %32, %73 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
|
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
|
||||||
%75 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
|
||||||
%76 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
|
||||||
%77:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %74, %arg30 = %75, %arg31 = %76) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
|
%77 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%84 = arith.index_cast %arg26 : index to i32
|
%78 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%85 = tt.splat %84 : (i32) -> tensor<128xi32, #blocked0>
|
%79:5 = scf.for %arg26 = %69 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
|
||||||
%86 = tt.splat %84 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
%86 = arith.index_cast %arg26 : index to i32
|
||||||
%87 = arith.addi %85, %14 : tensor<128xi32, #blocked0>
|
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
|
||||||
%88 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
||||||
%89 = triton_gpu.convert_layout %88 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
|
||||||
%90 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
|
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||||
%91 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
|
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
||||||
%92 = triton_gpu.convert_layout %91 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
|
%92 = triton_gpu.convert_layout %61 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
|
||||||
%93 = tt.dot %90, %92, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
|
%93 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
|
||||||
%94 = arith.addi %86, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
%94 = tt.dot %93, %92, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
|
||||||
%95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
|
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
||||||
%96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
|
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
|
||||||
%97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
|
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
|
||||||
%98 = "triton_gpu.select"(%97, %93, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
%98 = "triton_gpu.cmpi"(%97, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
|
||||||
%99 = tt.addptr %38, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
||||||
%100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
|
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
||||||
%101 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
|
||||||
%102 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0>
|
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
|
||||||
%103 = tt.expand_dims %101 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
|
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
||||||
%104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
|
||||||
%105 = arith.subf %102, %104 : tensor<128x128xf32, #mma0>
|
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
||||||
%106 = math.exp %105 : tensor<128x128xf32, #mma0>
|
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0>
|
||||||
%107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
|
||||||
%108 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||||
%109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
|
%109 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
|
||||||
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
|
%110 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
|
||||||
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
|
%111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
|
||||||
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
%112 = tt.trans %111 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
|
||||||
%113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
%113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
||||||
%114 = tt.dot %112, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
%114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
||||||
%115 = tt.addptr %40, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
%115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
||||||
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
|
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
||||||
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
|
||||||
%118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
|
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
|
||||||
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
|
||||||
%120 = arith.subf %cst_0, %119 : tensor<128x128xf32, #mma0>
|
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
|
||||||
%121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
|
%121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
|
||||||
%122 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
|
%122 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
|
||||||
%123 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
|
%123 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
|
||||||
%124 = tt.dot %122, %123, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
|
%124 = tt.dot %123, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
|
||||||
%125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0>
|
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0>
|
||||||
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
|
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
|
||||||
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
|
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
|
||||||
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
|
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
|
||||||
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
|
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
|
||||||
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
%130 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
||||||
%131 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
%131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
||||||
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
%132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
||||||
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
|
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
|
||||||
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
|
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
|
||||||
%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
%135 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
||||||
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
|
%136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
||||||
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
%137 = tt.dot %136, %135, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
|
||||||
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
|
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
|
||||||
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
|
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
|
||||||
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
|
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
|
||||||
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
scf.yield %114, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
|
scf.yield %115, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
|
||||||
}
|
}
|
||||||
%78 = arith.truncf %77#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
%80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
||||||
%79 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%81 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%80 = triton_gpu.convert_layout %78 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
%82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
||||||
tt.store %79, %80 : tensor<128x64xf16, #blocked1>
|
tt.store %81, %82 : tensor<128x64xf16, #blocked1>
|
||||||
%81 = arith.truncf %77#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
%83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
||||||
%82 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
%85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
||||||
tt.store %82, %83 : tensor<128x64xf16, #blocked1>
|
tt.store %84, %85 : tensor<128x64xf16, #blocked1>
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -906,10 +906,10 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
|||||||
pm.add_tritongpu_combine_pass(compute_capability)
|
pm.add_tritongpu_combine_pass(compute_capability)
|
||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
# pm.add_tritongpu_optimize_load_convert_pass()
|
# pm.add_tritongpu_optimize_load_convert_pass()
|
||||||
pm.add_tritongpu_sink_conversions_from_shared_pass()
|
|
||||||
pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
|
pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
|
||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
pm.add_symbol_dce_pass()
|
pm.add_symbol_dce_pass()
|
||||||
|
pm.add_tritongpu_sink_conversions_from_shared_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user