// Original listing metadata: 111 lines, 11 KiB, plaintext.
// Layout attribute aliases referenced by the tensor types in the module below.
// Every warpsPerCTA entry multiplies out to 8 warps (4*2 or 8*1).

// Blocked layouts: #blocked0 holds 1x1 elements per thread with order = [0, 1];
// #blocked1 holds 1x8 elements per thread with order = [1, 0].
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>

// MMA layouts, version 2.0; they differ only in warpsPerCTA.
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>

// Shared layouts with vec = perPhase = maxPhase = 1; they differ only in order.
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
  // Backward kernel operating on 128x64 f16 tiles.
  //
  // Per program it:
  //   1. derives a scalar element offset from the program id and offsets the
  //      base pointers %arg0, %arg1, %arg2, %arg5, %arg6, %arg7, %arg8 by it;
  //   2. loads one 128x64 tile from the %arg1-based and one from the
  //      %arg2-based pointer, and prepares their transposes as dot operands;
  //   3. loops %arg24 times, loading a tile from the %arg0-based and the
  //      %arg5-based pointer each iteration and accumulating two 128x64 f32
  //      results;
  //   4. truncates both accumulators to f16 and stores them through the
  //      %arg8-based and %arg7-based pointers.
  //
  // %arg3 is an f32 scalar multiplied into the second product inside the loop.
  // %arg4, %arg9-%arg11, %arg15, %arg16, %arg18-%arg21 and %arg23 are not
  // referenced in this body.
  // NOTE(review): the name and data flow look like a flash-attention backward
  // pass (%arg0/%arg1/%arg2 = Q/K/V, %arg5 = dO, %arg7/%arg8 = dK/dV, %arg3 =
  // softmax scale) - TODO confirm against the Python kernel this was lowered
  // from.
  func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
    // Zero tensors: %cst seeds the 128x128 products, %cst_0 seeds both
    // 128x64 loop accumulators.
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma0>
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %c128_i32 = arith.constant 128 : i32

    // Scalar offset: (pid / %arg22) * %arg12 + (pid % %arg22) * %arg13.
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    %1 = arith.divsi %0, %arg22 : i32
    %2 = arith.remsi %0, %arg22 : i32
    %3 = arith.muli %1, %arg12 : i32
    %4 = arith.muli %2, %arg13 : i32
    %5 = arith.addi %3, %4 : i32

    // Apply the same offset %5 to every base pointer this kernel uses.
    %6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
    %7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
    %8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
    %9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
    %10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
    %11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
    %12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32

    // Row indices 0..127 as 128x1 columns, once per blocked layout.
    %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %15 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<128x1xi32, #blocked0>
    %16 = tt.expand_dims %14 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>

    // Row stride %arg14 splatted in both layouts; %19 = row * %arg14 (#blocked0).
    %17 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked0>
    %18 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
    %19 = arith.muli %15, %17 : tensor<128x1xi32, #blocked0>

    // Column indices 0..63, broadcast to 128x64 in both layouts (%24, %26).
    %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.broadcast %19 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
    %23 = tt.expand_dims %20 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0>
    %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
    %25 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
    %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>

    // %27 = row * %arg14 + col in #blocked0; only used for the f32 pointer %34.
    %27 = arith.addi %22, %24 : tensor<128x64xi32, #blocked0>

    // Splat the offset scalar pointers into 128x64 pointer tensors.
    %28 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %29 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
    %30 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %31 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %32 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %33 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked0>
    %34 = tt.addptr %33, %27 : tensor<128x64x!tt.ptr<f32>, #blocked0>, tensor<128x64xi32, #blocked0>

    // %37 = row * %arg17 + col: tile offsets for the %arg1-based load here and
    // for the %arg7-based store after the loop.
    %35 = arith.muli %16, %29 : tensor<128x1xi32, #blocked1>
    %36 = tt.broadcast %35 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
    %37 = arith.addi %36, %26 : tensor<128x64xi32, #blocked1>
    %38 = tt.addptr %30, %37 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %39 = tt.load %38 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>

    // %42 = row * %arg14 + col: tile offsets shared by the %arg2-based load,
    // the %arg0- and %arg5-based loop pointers, and the %arg8-based store.
    %40 = arith.muli %16, %18 : tensor<128x1xi32, #blocked1>
    %41 = tt.broadcast %40 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
    %42 = arith.addi %41, %26 : tensor<128x64xi32, #blocked1>
    %43 = tt.addptr %31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %44 = tt.load %43 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>

    // Loop upper bound: %arg24 * 128, stepped by 128 below.
    %45 = arith.muli %arg24, %c128_i32 : i32
    %46 = arith.index_cast %45 : i32 to index

    // Transpose both loaded tiles (128x64 -> 64x128) via the shared layouts.
    %47 = triton_gpu.convert_layout %39 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
    %48 = tt.trans %47 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
    %49 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma1>
    %50 = triton_gpu.convert_layout %44 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
    %51 = tt.trans %50 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>

    // Per-iteration pointer increment: %arg14 * 128 elements.
    %52 = arith.muli %arg14, %c128_i32 : i32
    %53 = tt.splat %52 : (i32) -> tensor<128x64xi32, #blocked0>
    %54 = tt.splat %52 : (i32) -> tensor<128x64xi32, #blocked1>

    // Initial loop pointers for the %arg0-based and %arg5-based tiles.
    %55 = tt.addptr %28, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %56 = tt.addptr %32, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>

    // Loop-invariant right-hand dot operands (opIdx = 1) for the 128x128 products.
    %57 = triton_gpu.convert_layout %48 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
    %58 = triton_gpu.convert_layout %51 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>

    // Loop-carried state:
    //   %arg26, %arg27 - the two 128x64 f32 accumulators (results #0 and #1)
    //   %arg28         - f32 pointer tile based on %arg6 (advanced each
    //                    iteration but never dereferenced in this body)
    //   %arg29, %arg30 - f16 pointer tiles based on %arg0 and %arg5
    %59:5 = scf.for %arg25 = %c0 to %46 step %c128 iter_args(%arg26 = %cst_0, %arg27 = %cst_0, %arg28 = %34, %arg29 = %55, %arg30 = %56) -> (tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked0>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
      // %70 = tile(%arg29) x %57, a fresh 128x128 product seeded with zeros.
      %68 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
      %69 = triton_gpu.convert_layout %68 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
      %70 = tt.dot %69, %57, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1>

      // First accumulator: %79 = %arg26 + trans(f16(%70)) x tile(%arg30).
      %73 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
      %74 = arith.truncf %70 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1>
      %75 = triton_gpu.convert_layout %74 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #shared1>
      %76 = tt.trans %75 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
      %77 = triton_gpu.convert_layout %76 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
      %78 = triton_gpu.convert_layout %73 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
      %79 = tt.dot %77, %78, %arg26 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0>

      // %84 = %70 * (tile(%arg30) x %58) * splat(%arg3), elementwise.
      %80 = triton_gpu.convert_layout %73 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
      %81 = tt.dot %80, %58, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1>
      %83 = arith.mulf %70, %81 : tensor<128x128xf32, #mma1>
      %84 = arith.mulf %83, %49 : tensor<128x128xf32, #mma1>

      // Second accumulator: %90 = %arg27 + trans(f16(%84)) x tile(%arg29).
      %85 = arith.truncf %84 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1>
      %86 = triton_gpu.convert_layout %85 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #shared1>
      %87 = tt.trans %86 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
      %88 = triton_gpu.convert_layout %87 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
      %89 = triton_gpu.convert_layout %68 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
      %90 = tt.dot %88, %89, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0>

      // Advance all three pointer tiles by %arg14 * 128 elements.
      %91 = tt.addptr %arg28, %53 : tensor<128x64x!tt.ptr<f32>, #blocked0>, tensor<128x64xi32, #blocked0>
      %92 = tt.addptr %arg29, %54 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %93 = tt.addptr %arg30, %54 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>

      // Carry the UPDATED values into the next iteration. Yielding the
      // incoming block arguments instead would leave %90-%93 dead, reload the
      // same tiles every iteration, and keep result #1 at its zero seed.
      scf.yield %79, %90, %91, %92, %93 : tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked0>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
    }

    // Bring both accumulators back to #blocked1 for the stores.
    %60 = triton_gpu.convert_layout %59#1 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1>
    %61 = triton_gpu.convert_layout %59#0 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1>
    %62 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %63 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>

    // Result #0 -> %arg8-based pointer, using the %arg14-strided offsets %42.
    %64 = tt.addptr %62, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %65 = arith.truncf %61 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
    tt.store %64, %65 : tensor<128x64xf16, #blocked1>

    // Result #1 -> %arg7-based pointer, using the %arg17-strided offsets %37.
    %66 = tt.addptr %63, %37 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
    %67 = arith.truncf %60 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
    tt.store %66, %67 : tensor<128x64xf16, #blocked1>
    return
  }
}