This commit is contained in:
Phil Tillet
2023-01-02 23:13:12 -08:00
parent 05920e0b8b
commit 5c01c567b9
3 changed files with 81 additions and 82 deletions

View File

@@ -483,7 +483,7 @@ public:
return op->getBlock() == cvt->getBlock() && return op->getBlock() == cvt->getBlock() &&
!(isa<triton::ReduceOp>(op) && !(isa<triton::ReduceOp>(op) &&
!op->getResult(0).getType().isa<RankedTensorType>()) && !op->getResult(0).getType().isa<RankedTensorType>()) &&
// !isa<triton::gpu::ConvertLayoutOp>(op) && !isa<triton::gpu::ConvertLayoutOp>(op) &&
!isa<scf::YieldOp>(op); !isa<scf::YieldOp>(op);
}; };
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);

View File

@@ -1,20 +1,19 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> #mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 8, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} { module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32 %c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index %c128 = arith.constant 128 : index
%cst = arith.constant dense<0xFF800000> : tensor<128x128xf32, #blocked3> %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma0> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> %cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%0 = tt.get_program_id {axis = 0 : i32} : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32 %1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32 %2 = arith.remsi %0, %arg22 : i32
@@ -31,9 +30,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%13 = arith.index_cast %arg24 : i32 to index %13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1> %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2> %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
@@ -54,7 +53,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%36 = arith.muli %arg24, %c128_i32 : i32 %36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index %37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0> %38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #blocked3> %39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0> %40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32 %41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1> %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
@@ -65,7 +64,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%46 = arith.index_cast %arg25 : index to i32 %46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32 %47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
@@ -84,9 +83,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%65 = arith.index_cast %47 : i32 to index %65 = arith.index_cast %47 : i32 to index
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> %68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x128xi32, #blocked3> %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #blocked3>) -> tensor<128x128xi32, #blocked3> %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
@@ -95,74 +94,72 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%82:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst_0, %arg28 = %cst_0, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) { %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%89 = arith.index_cast %arg26 : index to i32 %86 = arith.index_cast %arg26 : index to i32
%90 = tt.splat %89 : (i32) -> tensor<128xi32, #blocked0> %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%91 = tt.splat %89 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%92 = arith.addi %90, %14 : tensor<128xi32, #blocked0> %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%93 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%94 = triton_gpu.convert_layout %93 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%95 = tt.dot %94, %79, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1> %93 = tt.dot %92, %91, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%96 = triton_gpu.convert_layout %95 : (tensor<128x128xf32, #mma1>) -> tensor<128x128xf32, #blocked3> %94 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%97 = arith.addi %91, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> %95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%98 = tt.expand_dims %97 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>) -> tensor<128x1xi32, #blocked3> %96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%99 = tt.broadcast %98 : (tensor<128x1xi32, #blocked3>) -> tensor<128x128xi32, #blocked3> %97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%100 = "triton_gpu.cmpi"(%99, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #blocked3>, tensor<128x128xi32, #blocked3>) -> tensor<128x128xi1, #blocked3> %98 = "triton_gpu.select"(%97, %93, %cst_1) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%101 = "triton_gpu.select"(%100, %96, %cst) : (tensor<128x128xi1, #blocked3>, tensor<128x128xf32, #blocked3>, tensor<128x128xf32, #blocked3>) -> tensor<128x128xf32, #blocked3> %99 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%102 = tt.addptr %38, %92 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%103 = tt.load %102 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %101 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0>
%104 = arith.mulf %101, %39 : tensor<128x128xf32, #blocked3> %102 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%105 = triton_gpu.convert_layout %103 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> %103 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%106 = tt.expand_dims %105 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>) -> tensor<128x1xf32, #blocked3> %104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%107 = tt.broadcast %106 : (tensor<128x1xf32, #blocked3>) -> tensor<128x128xf32, #blocked3> %105 = arith.subf %101, %104 : tensor<128x128xf32, #mma0>
%108 = arith.subf %104, %107 : tensor<128x128xf32, #blocked3> %106 = math.exp %105 : tensor<128x128xf32, #mma0>
%109 = math.exp %108 : tensor<128x128xf32, #blocked3> %107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%110 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %108 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%111 = arith.truncf %109 : tensor<128x128xf32, #blocked3> to tensor<128x128xf16, #blocked3> %109 = triton_gpu.convert_layout %108 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #shared1> %110 = tt.trans %109 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
%113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> %111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = triton_gpu.convert_layout %113 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %112 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %113 = tt.dot %111, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0> %114 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = tt.addptr %40, %92 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %115 = tt.load %114 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %116 = triton_gpu.convert_layout %115 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> %117 = tt.expand_dims %116 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>) -> tensor<128x1xf32, #mma1> %118 = tt.broadcast %117 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = tt.broadcast %120 : (tensor<128x1xf32, #mma1>) -> tensor<128x128xf32, #mma1> %119 = arith.subf %cst_0, %118 : tensor<128x128xf32, #mma0>
%122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma1> %120 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %121 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%80 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %122 = tt.dot %120, %121, %119 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%124 = tt.dot %123, %80, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1> %123 = arith.mulf %106, %122 : tensor<128x128xf32, #mma0>
%125 = triton_gpu.convert_layout %124 : (tensor<128x128xf32, #mma1>) -> tensor<128x128xf32, #blocked3> %124 = arith.mulf %123, %39 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %109, %125 : tensor<128x128xf32, #blocked3> %125 = arith.truncf %124 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%127 = arith.mulf %126, %39 : tensor<128x128xf32, #blocked3> %126 = triton_gpu.convert_layout %125 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%128 = arith.truncf %127 : tensor<128x128xf32, #blocked3> to tensor<128x128xf16, #blocked3> %127 = tt.trans %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
%129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #shared1> %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> %129 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%131 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %130 = tt.dot %128, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%132 = triton_gpu.convert_layout %93 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%133 = tt.dot %131, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0> %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> %133 = triton_gpu.convert_layout %125 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma0> %134 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%81 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
%137 = tt.dot %136, %81, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0> tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked2> %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2> %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %116, %133, %139, %140, %141 : tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
} }
%83 = triton_gpu.convert_layout %82#1 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1> %80 = triton_gpu.convert_layout %79#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
%84 = triton_gpu.convert_layout %82#0 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1> %81 = triton_gpu.convert_layout %79#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
%85 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%86 = arith.truncf %84 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1> %83 = arith.truncf %81 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
tt.store %85, %86 : tensor<128x64xf16, #blocked1> tt.store %82, %83 : tensor<128x64xf16, #blocked1>
%87 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%88 = arith.truncf %83 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1> %85 = arith.truncf %80 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
tt.store %87, %88 : tensor<128x64xf16, #blocked1> tt.store %84, %85 : tensor<128x64xf16, #blocked1>
} }
return return
} }

View File

@@ -326,6 +326,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
triton.testing.assert_almost_equal(ref_dv, tri_dv) triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk) triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq) triton.testing.assert_almost_equal(ref_dq, tri_dq)
print(ref_dk, tri_dk)
print(ref_dq, tri_dq)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4 # vary seq length for fixed head and batch=4