@@ -9,17 +9,16 @@
 #mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
 #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
-#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
-#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 8 : i32} {
   func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c128_i32 = arith.constant 128 : i32
-    %c128 = arith.constant 128 : index
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
     %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
-    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
+    %c128 = arith.constant 128 : index
+    %c128_i32 = arith.constant 128 : i32
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.divsi %0, %arg22 : i32
     %2 = arith.remsi %0, %arg22 : i32
@@ -82,93 +81,92 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
       %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
       %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-      %61 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
-      %62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
-      %63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
-      %64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1>
-      %65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-      %67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
-      %68 = arith.index_cast %47 : i32 to index
-      %69 = tt.trans %61 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2>
-      %70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
-      %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
-      %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
-      %73 = tt.trans %67 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2>
-      %74 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
-      %75 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
-      %76 = arith.addi %75, %26 : tensor<128x64xi32, #blocked2>
-      %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
-      %78 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %79 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %80:5 = scf.for %arg26 = %68 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %77, %arg30 = %78, %arg31 = %79) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
-        %87 = arith.index_cast %arg26 : index to i32
-        %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0>
-        %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
-        %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0>
-        %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-        %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-        %93 = triton_gpu.convert_layout %69 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-        %94 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-        %95 = tt.dot %94, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
-        %96 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
-        %97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
-        %98 = tt.broadcast %97 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
-        %99 = "triton_gpu.cmpi"(%98, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
-        %100 = "triton_gpu.select"(%99, %95, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
-        %101 = tt.addptr %38, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
-        %102 = tt.load %101 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
-        %103 = triton_gpu.convert_layout %102 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
-        %104 = arith.mulf %100, %39 : tensor<128x128xf32, #mma0>
-        %105 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
-        %106 = tt.broadcast %105 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
-        %107 = arith.subf %104, %106 : tensor<128x128xf32, #mma0>
-        %108 = math.exp %107 : tensor<128x128xf32, #mma0>
-        %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-        %110 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-        %111 = arith.truncf %108 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
-        %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
-        %113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2>
-        %114 = triton_gpu.convert_layout %113 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-        %115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-        %116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-        %117 = tt.addptr %40, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
-        %118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
-        %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
-        %120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
-        %121 = tt.broadcast %120 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
-        %122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma0>
-        %123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-        %124 = triton_gpu.convert_layout %73 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-        %125 = tt.dot %123, %124, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
-        %126 = arith.mulf %108, %125 : tensor<128x128xf32, #mma0>
-        %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0>
-        %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
-        %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
-        %130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2>
-        %131 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-        %132 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-        %133 = tt.dot %132, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-        %134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
-        %135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
-        %136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-        %137 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-        %138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
-        %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
-        tt.store %arg29, %139 : tensor<128x64xf32, #blocked2>
-        %140 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
-        %141 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-        %142 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-        scf.yield %116, %133, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
+      %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
+      %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
+      %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
+      %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
+      %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
+      %67 = arith.index_cast %47 : i32 to index
+      %68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+      %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
+      %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+      %71 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
+      %72 = tt.broadcast %71 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
+      %73 = arith.addi %72, %26 : tensor<128x64xi32, #blocked2>
+      %74 = tt.addptr %32, %73 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
+      %75 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %76 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %77:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst_1, %arg28 = %cst_1, %arg29 = %74, %arg30 = %75, %arg31 = %76) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
+        %84 = arith.index_cast %arg26 : index to i32
+        %85 = tt.splat %84 : (i32) -> tensor<128xi32, #blocked0>
+        %86 = tt.splat %84 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %87 = arith.addi %85, %14 : tensor<128xi32, #blocked0>
+        %88 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
+        %89 = triton_gpu.convert_layout %88 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
+        %90 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
+        %91 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+        %92 = triton_gpu.convert_layout %90 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+        %93 = tt.dot %91, %92, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+        %94 = arith.addi %86, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
+        %96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+        %97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
+        %98 = "triton_gpu.select"(%97, %93, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
+        %99 = tt.addptr %38, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
+        %100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
+        %101 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %102 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0>
+        %103 = tt.expand_dims %101 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
+        %104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
+        %105 = arith.subf %102, %104 : tensor<128x128xf32, #mma0>
+        %106 = math.exp %105 : tensor<128x128xf32, #mma0>
+        %107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
+        %108 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
+        %109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
+        %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
+        %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
+        %112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %115 = tt.addptr %40, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
+        %116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
+        %117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
+        %119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
+        %120 = arith.subf %cst, %119 : tensor<128x128xf32, #mma0>
+        %121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
+        %122 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+        %123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+        %124 = tt.dot %123, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+        %125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0>
+        %126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
+        %127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
+        %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
+        %129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
+        %130 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
+        %134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
+        %135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
+        tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
+        %139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
+        %140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+        %141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+        scf.yield %114, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
       }
-      %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
-      %82 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %78 = arith.truncf %77#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+      %79 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %80 = triton_gpu.convert_layout %78 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
+      tt.store %79, %80 : tensor<128x64xf16, #blocked1>
+      %81 = arith.truncf %77#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+      %82 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
       tt.store %82, %83 : tensor<128x64xf16, #blocked1>
-      %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
-      %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
-      tt.store %85, %86 : tensor<128x64xf16, #blocked1>
     }
     return
   }