This commit is contained in:
Phil Tillet
2023-01-09 18:45:44 -08:00
parent 0c101e0c33
commit fc1007278d
5 changed files with 137 additions and 136 deletions

View File

@@ -63,15 +63,14 @@ public:
for(auto &kv: opToMove)
kv.first->moveBefore(kv.second);
// Move transpositions just before their first use
// Move transpositions just after their definition
opToMove.clear();
m.walk([&](triton::TransOp op){
auto user_begin = op->user_begin();
opToMove.insert({op, *user_begin});
Operation* argOp =op.getOperand().getDefiningOp();
if(!argOp)
return;
op->moveAfter(argOp);
});
for(auto &kv: opToMove)
kv.first->moveBefore(kv.second);
return;
}

View File

@@ -10,6 +10,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%c128 = arith.constant 128 : index
%c128_i32 = arith.constant 128 : i32
%c1 = arith.constant 1 : index
@@ -121,9 +122,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%114 = tt.dot %112, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%115 = tt.addptr %40, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
@@ -131,17 +132,17 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%120 = arith.subf %cst, %119 : tensor<128x128xf32, #mma0>
%121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%122 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%124 = tt.dot %123, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>

View File

@@ -3,18 +3,18 @@
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
@@ -82,13 +82,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%65 = arith.index_cast %47 : i32 to index
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0>
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
@@ -100,69 +100,69 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%900 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%90 = triton_gpu.convert_layout %900 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared0>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%93 = tt.dot %92, %91, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%94 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%98 = "triton_gpu.select"(%97, %93, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%99 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%101 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0>
%102 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%103 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%105 = arith.subf %101, %104 : tensor<128x128xf32, #mma0>
%106 = math.exp %105 : tensor<128x128xf32, #mma0>
%1070 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%107 = triton_gpu.convert_layout %1070 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%108 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%109 = triton_gpu.convert_layout %108 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%110 = tt.trans %109 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
%111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%112 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%113 = tt.dot %111, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%114 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%115 = tt.load %114 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%116 = triton_gpu.convert_layout %115 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%117 = tt.expand_dims %116 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%118 = tt.broadcast %117 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%119 = arith.subf %cst_0, %118 : tensor<128x128xf32, #mma0>
%120 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%121 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared0>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%122 = tt.dot %120, %121, %119 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%123 = arith.mulf %106, %122 : tensor<128x128xf32, #mma0>
%124 = arith.mulf %123, %39 : tensor<128x128xf32, #mma0>
%125 = arith.truncf %124 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%126 = triton_gpu.convert_layout %125 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%127 = tt.trans %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%129 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%130 = tt.dot %128, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%133 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%134 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%135 = tt.dot %133, %134, %cst_10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
%140 = arith.addf %136, %131 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %140: tensor<128x64xf32, #blocked2>
%137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%110 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%112 = tt.trans %111 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = arith.subf %cst_0, %120 : tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%135 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%136 = tt.dot %134, %135, %cst_2 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%137 = triton_gpu.convert_layout %136 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
%138 = arith.addf %137, %133 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%80 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %82, %83 : tensor<128x64xf16, #blocked1>
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%80 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%85 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %84, %85 : tensor<128x64xf16, #blocked1>
%82 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %80, %82 : tensor<128x64xf16, #blocked1>
%83 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%84 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%85 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %83, %85 : tensor<128x64xf16, #blocked1>
}
return
}

View File

@@ -7,13 +7,14 @@
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
@@ -102,12 +103,12 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%90 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%91 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%92 = triton_gpu.convert_layout %91 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%93 = tt.dot %90, %92, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%93 = tt.dot %90, %92, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%94 = arith.addi %86, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%98 = "triton_gpu.select"(%97, %93, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%98 = "triton_gpu.select"(%97, %93, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%99 = tt.addptr %38, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%101 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
@@ -117,23 +118,23 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%105 = arith.subf %102, %104 : tensor<128x128xf32, #mma0>
%106 = math.exp %105 : tensor<128x128xf32, #mma0>
%107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%108 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%109 = triton_gpu.convert_layout %108 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%110 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%111 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%112 = tt.trans %109 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = tt.dot %113, %111, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%108 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%114 = tt.dot %112, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%115 = tt.addptr %40, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%120 = arith.subf %cst_1, %119 : tensor<128x128xf32, #mma0>
%120 = arith.subf %cst_0, %119 : tensor<128x128xf32, #mma0>
%121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%122 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%124 = tt.dot %123, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
@@ -144,9 +145,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%137 = tt.dot %136, %135, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>

View File

@@ -191,7 +191,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
_bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
@@ -260,36 +260,36 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],1,1)](
q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
o.data_ptr(), do_scaled.data_ptr(),
dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
l.data_ptr(), m.data_ptr(),
delta.data_ptr(),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0]
)
# pgm = _bwd_kernel[(ctx.grid[1],)](
# q, k, v, ctx.sm_scale,
# o, do_scaled,
# dq, dk, dv,
# l, m,
# delta,
# q.stride(0), q.stride(1), q.stride(2), q.stride(3),
# k.stride(0), k.stride(1), k.stride(2), k.stride(3),
# v.stride(0), v.stride(1), v.stride(2), v.stride(3),
# _bwd_kernel[(ctx.grid[1],1,1)](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
# o.data_ptr(), do_scaled.data_ptr(),
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
# l.data_ptr(), m.data_ptr(),
# delta.data_ptr(),
# q.stride(0), q.stride(1), q.stride(2),
# k.stride(0), k.stride(1), k.stride(2),
# v.stride(0), v.stride(1), v.stride(2),
# q.shape[0], q.shape[1], q.shape[2],
# ctx.grid[0],
# BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
# BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
# num_stages=1,
# ctx.grid[0]
# )
# print(pgm.asm["ttgir"])
# exit()
pgm = _bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
print(pgm.asm["ttgir"])
exit()
return dq, dk, dv, None