cleanup
@@ -1,169 +0,0 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%c128 = arith.constant 128 : index
%c128_i32 = arith.constant 128 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
%3 = arith.muli %1, %arg12 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.addi %3, %4 : i32
%6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
%7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
%8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
%9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
%10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
%11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
%13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
%29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
%33 = arith.muli %0, %arg23 : i32
%34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
%35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
%36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
%43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
%44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
scf.for %arg25 = %c0 to %13 step %c1 {
%46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
%64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = arith.index_cast %47 : i32 to index
%68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%71 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%72 = tt.broadcast %71 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%73 = arith.addi %72, %26 : tensor<128x64xi32, #blocked2>
%74 = tt.addptr %32, %73 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%75 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%76 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%77:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst_1, %arg28 = %cst_1, %arg29 = %74, %arg30 = %75, %arg31 = %76) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
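// An interpretation of the loop-carried values (not an annotation from the
// original source): %arg27 and %arg28 are the two 128x64 f32 accumulators
// updated by the tt.dot ops below, consistent with the dv/dk accumulators of
// a flash-attention backward pass; %arg29 walks f32 rows that are
// read-modify-written each iteration, while %arg30 and %arg31 walk the two
// f16 operands loaded per iteration.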
%84 = arith.index_cast %arg26 : index to i32
%85 = tt.splat %84 : (i32) -> tensor<128xi32, #blocked0>
%86 = tt.splat %84 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%87 = arith.addi %85, %14 : tensor<128xi32, #blocked0>
%88 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%89 = triton_gpu.convert_layout %88 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%90 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%91 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%92 = triton_gpu.convert_layout %90 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%93 = tt.dot %91, %92, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%94 = arith.addi %86, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
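// Predicate 5 matches "sge" in the arith.cmpi numbering, so %97 is true
// where the row index (%96) >= the column index (%70); the select below
// replaces the remaining scores with %cst_0 = 0xFF800000 (-inf), which reads
// as a causal mask.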
%98 = "triton_gpu.select"(%97, %93, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%99 = tt.addptr %38, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%101 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%102 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0>
%103 = tt.expand_dims %101 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%105 = arith.subf %102, %104 : tensor<128x128xf32, #mma0>
%106 = math.exp %105 : tensor<128x128xf32, #mma0>
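// %102 through %106 appear to recompute the attention probabilities from
// stored row statistics: p = exp(qk * scale - l), with %39 the splatted
// scale %arg3 and %104 the per-row values loaded through %38. This is an
// inference from the IR structure; the dump itself carries no such labels.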
%107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%108 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%114 = tt.dot %112, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%115 = tt.addptr %40, %87 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%120 = arith.subf %cst, %119 : tensor<128x128xf32, #mma0>
%121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%122 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0>
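// Continuing that reading, %124 = dot(do, v^T) + (0 - D) via the %120
// accumulator, %125 = p * %124, and %126 rescales by %39: the usual ds
// computation of an attention backward pass (hypothesis, not part of the
// original dump).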
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %114, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%78 = arith.truncf %77#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%79 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%80 = triton_gpu.convert_layout %78 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %79, %80 : tensor<128x64xf16, #blocked1>
%81 = arith.truncf %77#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%82 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %82, %83 : tensor<128x64xf16, #blocked1>
}
return
}
}
python/bwd.ptx (2764 lines): file diff suppressed because it is too large
python/bwd.ttgir (169 lines):
@@ -1,169 +0,0 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
%3 = arith.muli %1, %arg12 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.addi %3, %4 : i32
%6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
%7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
%8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
%9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
%10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
%11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
%13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
%29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
%33 = arith.muli %0, %arg23 : i32
%34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
%35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
%36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
%43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
%44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
scf.for %arg25 = %c0 to %13 step %c1 {
%46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1>
%63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%65 = arith.index_cast %47 : i32 to index
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%86 = arith.index_cast %arg26 : index to i32
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%110 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%112 = tt.trans %111 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = arith.subf %cst_0, %120 : tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%135 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%136 = tt.dot %134, %135, %cst_2 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%137 = triton_gpu.convert_layout %136 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
%138 = arith.addf %137, %133 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%80 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%82 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %80, %82 : tensor<128x64xf16, #blocked1>
%83 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%84 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%85 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %83, %85 : tensor<128x64xf16, #blocked1>
}
return
}
}
@@ -1,159 +0,0 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#mma_s1 = #triton_gpu.slice<{dim = 1, parent = #mma}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @_fwd_kernel_0d1d2d34d5d6d7d8d9d10c11d12d13d14c15d16d17d18c19d20d21d22c2324d25d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}) {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma>
%cst_2 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%c1_i32 = arith.constant 1 : i32
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%2 = arith.muli %0, %c128_i32 : i32
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%6 = tt.splat %2 : (i32) -> tensor<128xi32, #blocked0>
%7 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%8 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%9 = arith.addi %6, %3 : tensor<128xi32, #blocked0>
%10 = arith.muli %1, %arg8 : i32
%11 = arith.addi %7, %4 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%12 = arith.addi %8, %5 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%13 = tt.splat %arg9 : (i32) -> tensor<128x1xi32, #blocked1>
%14 = tt.splat %10 : (i32) -> tensor<128x1xi32, #blocked1>
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%16 = tt.expand_dims %15 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%17 = tt.broadcast %16 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%20 = tt.expand_dims %19 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x128xi32, #mma>
%21 = tt.splat %arg12 : (i32) -> tensor<1x128xi32, #blocked2>
%22 = tt.splat %10 : (i32) -> tensor<1x128xi32, #blocked2>
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%24 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2>
%25 = tt.expand_dims %18 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
%26 = arith.muli %25, %21 : tensor<1x128xi32, #blocked2>
%27 = arith.addi %22, %26 : tensor<1x128xi32, #blocked2>
%28 = tt.broadcast %27 : (tensor<1x128xi32, #blocked2>) -> tensor<64x128xi32, #blocked2>
%29 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<64x128x!tt.ptr<f16>, #blocked2>
%31 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%33 = arith.muli %32, %13 : tensor<128x1xi32, #blocked1>
%34 = arith.addi %14, %33 : tensor<128x1xi32, #blocked1>
%35 = tt.broadcast %34 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%36 = arith.addi %35, %17 : tensor<128x64xi32, #blocked1>
%37 = tt.addptr %29, %36 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%38 = arith.addi %0, %c1_i32 : i32
%39 = arith.muli %38, %c128_i32 : i32
%40 = arith.index_cast %39 : i32 to index
%41 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma>
%42 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
%43 = tt.broadcast %42 : (tensor<128x1xi32, #mma>) -> tensor<128x128xi32, #mma>
%44 = arith.muli %arg12, %c128_i32 : i32
%45 = tt.splat %44 : (i32) -> tensor<64x128xi32, #blocked2>
%46 = arith.muli %arg15, %c128_i32 : i32
%47 = tt.splat %46 : (i32) -> tensor<128x64xi32, #blocked1>
%48 = tt.broadcast %24 : (tensor<64x1xi32, #blocked2>) -> tensor<64x128xi32, #blocked2>
%49 = arith.addi %28, %48 : tensor<64x128xi32, #blocked2>
%50 = tt.addptr %30, %49 : tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<64x128xi32, #blocked2>
%51 = tt.expand_dims %4 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%52 = arith.muli %51, %13 : tensor<128x1xi32, #blocked1>
%53 = arith.addi %14, %52 : tensor<128x1xi32, #blocked1>
%54 = tt.broadcast %53 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%55 = arith.addi %54, %17 : tensor<128x64xi32, #blocked1>
%56 = tt.addptr %31, %55 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0>
// TODO: Load should be transformed into `insert_slice_async + extract_slice` at the very end of the optimization pass so it benefits from LICM
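// For reference, the async pattern below stands in for what would otherwise
// be a plain load plus layout conversion of the same operand, e.g.
// (illustrative only, not from the original file):
//   %q = tt.load %37 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>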
%80 = triton_gpu.insert_slice_async %37, %79, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked1> -> tensor<1x128x64xf16, #shared0>
triton_gpu.async_wait {num = 0 : i32}
%81 = tensor.extract_slice %80[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0>
%82 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%57:5 = scf.for %arg22 = %c0 to %40 step %c128 iter_args(%arg23 = %cst_4, %arg24 = %cst_3, %arg25 = %cst_2, %arg26 = %50, %arg27 = %56) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
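// Loop-carried values as we read them (the online-softmax recurrence of a
// flash-attention forward pass; unlabelled in the source): %arg23 is the
// running row sum of exponentiated scores, %arg24 the rescaled output
// accumulator, %arg25 the running row max, and %arg26/%arg27 the advancing
// operand pointer tensors.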
%78 = arith.index_cast %arg22 : index to i32
%83 = triton_gpu.alloc_tensor : tensor<1x64x128xf16, #shared1>
%84 = triton_gpu.insert_slice_async %arg26, %83, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr<f16>, #blocked2> -> tensor<1x64x128xf16, #shared1>
triton_gpu.async_wait {num = 0 : i32}
%85 = tensor.extract_slice %84[0, 0, 0] [1, 64, 128] [1, 1, 1] : tensor<1x64x128xf16, #shared1> to tensor<64x128xf16, #shared1>
%86 = triton_gpu.convert_layout %85 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
%87 = tt.dot %82, %86, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x128xf32, #mma>
%88 = tt.splat %78 : (i32) -> tensor<1x128xi32, #mma>
%89 = arith.addi %88, %20 : tensor<1x128xi32, #mma>
%90 = tt.broadcast %89 : (tensor<1x128xi32, #mma>) -> tensor<128x128xi32, #mma>
%91 = arith.mulf %87, %41 : tensor<128x128xf32, #mma>
%92 = "triton_gpu.cmpi"(%43, %90) {predicate = 5 : i64} : (tensor<128x128xi32, #mma>, tensor<128x128xi32, #mma>) -> tensor<128x128xi1, #mma>
%93 = "triton_gpu.select"(%92, %91, %cst_1) : (tensor<128x128xi1, #mma>, tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #mma>
%94 = tt.reduce %93 {axis = 1 : i32, redOp = 12 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%95 = "triton_gpu.cmpf"(%94, %arg25) {predicate = 2 : i64} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%96 = "triton_gpu.select"(%95, %94, %arg25) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
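// Predicate 2 appears to match "ogt" in the arith.cmpf numbering, so this
// compare-and-select pair computes the elementwise maximum of the fresh row
// max (%94) and the carried max (%arg25), i.e. the running-max update.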
%97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%98 = tt.broadcast %97 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma>
%99 = arith.subf %93, %98 : tensor<128x128xf32, #mma>
%100 = math.exp %99 : tensor<128x128xf32, #mma>
%101 = tt.reduce %100 {axis = 1 : i32, redOp = 2 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%102 = arith.subf %arg25, %96 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%103 = math.exp %102 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%104 = arith.mulf %arg23, %103 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%105 = arith.addf %101, %104 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%106 = arith.divf %cst, %105 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%107 = arith.mulf %104, %106 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%108 = tt.expand_dims %107 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%109 = tt.broadcast %108 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
%110 = arith.mulf %arg24, %109 : tensor<128x64xf32, #mma>
%111 = tt.expand_dims %106 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%112 = tt.broadcast %111 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma>
%113 = arith.mulf %100, %112 : tensor<128x128xf32, #mma>
%114 = arith.truncf %113 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%115 = triton_gpu.convert_layout %114 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%116 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0>
%117 = triton_gpu.insert_slice_async %arg27, %116, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked1> -> tensor<1x128x64xf16, #shared0>
triton_gpu.async_wait {num = 0 : i32}
%118 = tensor.extract_slice %117[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0>
%119 = triton_gpu.convert_layout %118 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
%120 = tt.dot %115, %119, %110 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma>
%121 = tt.addptr %arg26, %45 : tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<64x128xi32, #blocked2>
%122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%203 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #mma_s1>
%206 = tt.splat %2 : (i32) -> tensor<128xi32, #mma_s1>
%209 = arith.addi %206, %203 : tensor<128xi32, #mma_s1>
%61 = arith.muli %1, %arg21 : i32
%62 = tt.addptr %arg4, %61 : !tt.ptr<f32>, i32
%63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
%64 = tt.addptr %63, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
%65 = tt.addptr %arg5, %61 : !tt.ptr<f32>, i32
%66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
%67 = tt.addptr %66, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
tt.store %64, %57#0 : tensor<128xf32, #mma_s1>
tt.store %67, %57#2 : tensor<128xf32, #mma_s1>
%68 = arith.muli %1, %arg17 : i32
%69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
%70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>
%71 = tt.splat %arg6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%72 = arith.muli %32, %69 : tensor<128x1xi32, #blocked1>
%73 = arith.addi %70, %72 : tensor<128x1xi32, #blocked1>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%75 = arith.addi %74, %17 : tensor<128x64xi32, #blocked1>
%76 = tt.addptr %71, %75 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%77 = arith.truncf %57#1 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
// TODO: conversion should be here, not right after the loop
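// The conversion in question is the #mma -> #blocked1 convert_layout on the
// accumulator below, which exists so that the final tt.store writes through
// the coalesced #blocked1 layout (our reading of the TODO above).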
%78 = triton_gpu.convert_layout %77 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked1>
tt.store %76, %78 : tensor<128x64xf16, #blocked1>
return
}
}
@@ -1,168 +0,0 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
%3 = arith.muli %1, %arg12 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.addi %3, %4 : i32
%6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
%7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
%8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
%9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
%10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
%11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
%13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
%29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
%33 = arith.muli %0, %arg23 : i32
%34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
%35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
%36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
%43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
%44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
scf.for %arg25 = %c0 to %13 step %c1 {
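// Outer loop over 128-row K/V blocks: the two tiles below (%59, %66) are loaded once per
// block, staged through shared memory and transposed, then reused by every iteration of
// the inner loop. (Reading assumes the usual flash-attention backward structure.)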
%46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%61 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1>
%65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%68 = tt.trans %67 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%69 = arith.index_cast %47 : i32 to index
%70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79:5 = scf.for %arg26 = %69 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
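// Inner loop over 128-row Q blocks. iter_args carry the two f32 accumulators (%arg27,
// %arg28, presumably dV and dK) plus the three tile pointers (%arg29, %arg30, %arg31,
// presumably dQ, Q and dO), each advanced by 128 rows (%42/%43) per iteration.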
%86 = arith.index_cast %arg26 : index to i32
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%93 = triton_gpu.convert_layout %61 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%110 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%112 = tt.trans %111 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%81 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %81, %82 : tensor<128x64xf16, #blocked1>
%83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %84, %85 : tensor<128x64xf16, #blocked1>
}
return
}
}
@@ -1,178 +0,0 @@
// TODO: swizzle
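// (Presumably: #shared0/#shared1 below are unswizzled (vec = 1, maxPhase = 1), so the
// transposed operands go through un-swizzled shared memory; only the direct load->dot
// path uses the swizzled #shared2.)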
// TODO: move opIdx = 0 before opIdx = 1
// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
// TODO: don't convert loaded value to mma for accumulation
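// (Concretely: %135 below loads the dQ accumulator, %136 converts it to #mma1 just to
// seed the tt.dot, and %141 converts the result back to #blocked2 before the store.)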
// triton-opt unoptimized.ttgir -tritongpu-sink-conversions-from-shared -tritongpu-decompose-conversions-to-dot-operand -cse
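// -cse is MLIR's common-subexpression-elimination pass; judging by their names, the two
// tritongpu passes sink layout conversions out of shared memory toward their uses and
// decompose conversions into dot-operand form. The optimized listing above appears to be
// the result of running this pipeline on this file (an assumption based on these comments).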
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
%3 = arith.muli %1, %arg12 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.addi %3, %4 : i32
%6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
%7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
%8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
%9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
%10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
%11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
%13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
%29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
%33 = arith.muli %0, %arg23 : i32
%34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
%35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
%36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
%43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
%44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
scf.for %arg25 = %c0 to %13 step %c1 {
%46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
%64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = arith.index_cast %47 : i32 to index
%68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%86 = arith.index_cast %arg26 : index to i32
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%93 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%94 = tt.dot %93, %91, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%114 = triton_gpu.convert_layout %113 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = tt.dot %112, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%124 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%125 = tt.dot %124, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%126 = arith.mulf %107, %125 : tensor<128x128xf32, #mma0>
%127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0>
%128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%130 = tt.trans %129 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%131 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%132 = triton_gpu.convert_layout %131 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%133 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%134 = tt.dot %133, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%135 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%138 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%139 = triton_gpu.convert_layout %138 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%140 = tt.dot %137, %139, %136 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%141 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %141 : tensor<128x64xf32, #blocked2>
%142 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%143 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%144 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %134, %142, %143, %144 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %81, %82 : tensor<128x64xf16, #blocked1>
%83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %84, %85 : tensor<128x64xf16, #blocked1>
}
return
}
}