reduced some spilling
@@ -483,7 +483,7 @@ public:
       return op->getBlock() == cvt->getBlock() &&
              !(isa<triton::ReduceOp>(op) &&
                !op->getResult(0).getType().isa<RankedTensorType>()) &&
-            !isa<triton::gpu::ConvertLayoutOp>(op) &&
+            // !isa<triton::gpu::ConvertLayoutOp>(op) &&
             !isa<scf::YieldOp>(op);
     };
     mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);

python/bwd.ttgir
@@ -1,20 +1,20 @@
-#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
-#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
-#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
+#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
-// TODO: swizzle
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
-#shared1 = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 
 module attributes {"triton_gpu.num-warps" = 8 : i32} {
   func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
-    %cst = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
-    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
-    %c128 = arith.constant 128 : index
-    %c128_i32 = arith.constant 128 : i32
-    %c1 = arith.constant 1 : index
     %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c128_i32 = arith.constant 128 : i32
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant dense<0xFF800000> : tensor<128x128xf32, #blocked3>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma0>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1>
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.divsi %0, %arg22 : i32
     %2 = arith.remsi %0, %arg22 : i32
@@ -29,134 +29,140 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
     %12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
     %13 = arith.index_cast %arg24 : i32 to index
-    %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
+    %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
-    %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+    %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-    %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+    %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
-    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-    %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+    %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
-    %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+    %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
-    %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked0>
+    %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
-    %21 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
+    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-    %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
+    %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-    %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+    %23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
-    %24 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0>
+    %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
-    %25 = tt.broadcast %24 : (tensor<1x64xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
+    %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
-    %26 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
+    %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
-    %27 = tt.broadcast %26 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
+    %27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %28 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
-    %29 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked0>
+    %29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %30 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %31 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %32 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
-    %33 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked1>
+    %33 = arith.muli %0, %arg23 : i32
-    %34 = arith.muli %arg24, %c128_i32 : i32
+    %34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
-    %35 = arith.index_cast %34 : i32 to index
+    %35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
-    %36 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
+    %36 = arith.muli %arg24, %c128_i32 : i32
-    %37 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
+    %37 = arith.index_cast %36 : i32 to index
-    %38 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
+    %38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
-    %39 = arith.muli %arg14, %c128_i32 : i32
+    %39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #blocked3>
-    %40 = tt.splat %39 : (i32) -> tensor<128x64xi32, #blocked0>
+    %40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
-    %41 = tt.splat %39 : (i32) -> tensor<128x64xi32, #blocked1>
+    %41 = arith.muli %arg14, %c128_i32 : i32
-    %42 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
-    %43 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
+    %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
+    %44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
     scf.for %arg25 = %c0 to %13 step %c1 {
-      %44 = arith.index_cast %arg25 : index to i32
+      %46 = arith.index_cast %arg25 : index to i32
-      %45 = arith.muli %44, %c128_i32 : i32
+      %47 = arith.muli %46, %c128_i32 : i32
-      %46 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
+      %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %47 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+      %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
-      %48 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+      %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %49 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+      %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %50 = arith.addi %46, %14 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
+      %52 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %51 = arith.addi %49, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %52 = tt.expand_dims %50 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<128x1xi32, #blocked0>
       %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
-      %54 = arith.muli %52, %29 : tensor<128x1xi32, #blocked0>
+      %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
-      %55 = tt.broadcast %54 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
+      %55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
-      %56 = arith.addi %55, %25 : tensor<128x64xi32, #blocked0>
+      %56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
-      %57 = tt.addptr %30, %56 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
-      %58 = tt.load %57 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0>
+      %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %59 = arith.muli %52, %20 : tensor<128x1xi32, #blocked0>
+      %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-      %60 = tt.broadcast %59 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
+      %60 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
-      %61 = arith.addi %60, %25 : tensor<128x64xi32, #blocked0>
+      %61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
-      %62 = tt.addptr %31, %61 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1>
-      %63 = tt.load %62 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0>
+      %63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %64 = arith.index_cast %45 : i32 to index
+      %64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-      %65 = triton_gpu.convert_layout %58 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #shared0>
+      %65 = arith.index_cast %47 : i32 to index
-      %66 = tt.trans %65 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
+      %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-      %67 = arith.addi %47, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+      %67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
-      %68 = tt.expand_dims %67 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
+      %68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
-      %69 = tt.broadcast %68 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+      %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x128xi32, #blocked3>
-      %70 = arith.addi %48, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
+      %70 = tt.broadcast %69 : (tensor<1x128xi32, #blocked3>) -> tensor<128x128xi32, #blocked3>
-      %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
+      %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-      %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+      %72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
-      %73 = triton_gpu.convert_layout %63 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #shared0>
+      %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
-      %74 = tt.trans %73 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
+      %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
-      %75 = arith.muli %53, %21 : tensor<128x1xi32, #blocked1>
+      %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
-      %76 = tt.broadcast %75 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
+      %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
-      %77 = arith.addi %76, %27 : tensor<128x64xi32, #blocked1>
+      %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %78 = tt.addptr %33, %77 : tensor<128x64x!tt.ptr<f32>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %79 = tt.addptr %28, %61 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %79 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-      %80 = tt.addptr %32, %61 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %82:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst_0, %arg28 = %cst_0, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
-      %81:5 = scf.for %arg26 = %64 to %35 step %c128 iter_args(%arg27 = %cst_1, %arg28 = %cst_1, %arg29 = %78, %arg30 = %79, %arg31 = %80) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64x!tt.ptr<f16>, #blocked0>) {
+        %89 = arith.index_cast %arg26 : index to i32
-        %88 = arith.index_cast %arg26 : index to i32
+        %90 = tt.splat %89 : (i32) -> tensor<128xi32, #blocked0>
-        %89 = tt.splat %88 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %91 = tt.splat %89 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
-        %90 = tt.splat %88 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %92 = arith.addi %90, %14 : tensor<128xi32, #blocked0>
-        %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0>
+        %93 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-        %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+        %94 = triton_gpu.convert_layout %93 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-        %93 = triton_gpu.convert_layout %66 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+        %95 = tt.dot %94, %79, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1>
-        %94 = tt.dot %92, %93, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+        %96 = triton_gpu.convert_layout %95 : (tensor<128x128xf32, #mma1>) -> tensor<128x128xf32, #blocked3>
-        %95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %97 = arith.addi %91, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
-        %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
+        %98 = tt.expand_dims %97 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>) -> tensor<128x1xi32, #blocked3>
-        %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+        %99 = tt.broadcast %98 : (tensor<128x1xi32, #blocked3>) -> tensor<128x128xi32, #blocked3>
-        %98 = "triton_gpu.cmpi"(%97, %69) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
+        %100 = "triton_gpu.cmpi"(%99, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #blocked3>, tensor<128x128xi32, #blocked3>) -> tensor<128x128xi1, #blocked3>
-        %99 = "triton_gpu.select"(%98, %94, %cst) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
+        %101 = "triton_gpu.select"(%100, %96, %cst) : (tensor<128x128xi1, #blocked3>, tensor<128x128xf32, #blocked3>, tensor<128x128xf32, #blocked3>) -> tensor<128x128xf32, #blocked3>
-        %100 = arith.mulf %99, %36 : tensor<128x128xf32, #mma0>
+        %102 = tt.addptr %38, %92 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
-        %101 = math.exp %100 : tensor<128x128xf32, #mma0>
+        %103 = tt.load %102 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
-        %102 = arith.addi %90, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
+        %104 = arith.mulf %101, %39 : tensor<128x128xf32, #blocked3>
-        %103 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
+        %105 = triton_gpu.convert_layout %103 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
-        %104 = tt.broadcast %103 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
+        %106 = tt.expand_dims %105 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>) -> tensor<128x1xf32, #blocked3>
-        %105 = "triton_gpu.cmpi"(%104, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
+        %107 = tt.broadcast %106 : (tensor<128x1xf32, #blocked3>) -> tensor<128x128xf32, #blocked3>
-        %106 = "triton_gpu.select"(%105, %94, %cst) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
+        %108 = arith.subf %104, %107 : tensor<128x128xf32, #blocked3>
-        %107 = arith.mulf %106, %37 : tensor<128x128xf32, #mma0>
+        %109 = math.exp %108 : tensor<128x128xf32, #blocked3>
-        %108 = math.exp %107 : tensor<128x128xf32, #mma0>
+        %110 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
-        %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0>
+        %111 = arith.truncf %109 : tensor<128x128xf32, #blocked3> to tensor<128x128xf16, #blocked3>
-        %110 = arith.truncf %101 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
+        %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #shared1>
-        %111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
+        %113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
-        %112 = tt.trans %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
+        %114 = triton_gpu.convert_layout %113 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-        %113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-        %114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0>
-        %115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %117 = tt.addptr %40, %92 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
-        %116 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+        %118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
-        %117 = triton_gpu.convert_layout %74 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+        %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
-        %118 = tt.dot %116, %117, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+        %120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>) -> tensor<128x1xf32, #mma1>
-        %119 = arith.mulf %108, %118 : tensor<128x128xf32, #mma0>
+        %121 = tt.broadcast %120 : (tensor<128x1xf32, #mma1>) -> tensor<128x128xf32, #mma1>
-        %120 = arith.mulf %119, %38 : tensor<128x128xf32, #mma0>
+        %122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma1>
-        %121 = arith.truncf %120 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
+        %123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-        %122 = triton_gpu.convert_layout %121 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
+        %80 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-        %123 = tt.trans %122 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
+        %124 = tt.dot %123, %80, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1>
-        %124 = triton_gpu.convert_layout %123 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %125 = triton_gpu.convert_layout %124 : (tensor<128x128xf32, #mma1>) -> tensor<128x128xf32, #blocked3>
-        %125 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %126 = arith.mulf %109, %125 : tensor<128x128xf32, #blocked3>
-        %126 = tt.dot %124, %125, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %127 = arith.mulf %126, %39 : tensor<128x128xf32, #blocked3>
-        %127 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked1>
+        %128 = arith.truncf %127 : tensor<128x128xf32, #blocked3> to tensor<128x128xf16, #blocked3>
-        %128 = triton_gpu.convert_layout %127 : (tensor<128x64xf32, #blocked1>) -> tensor<128x64xf32, #mma1>
+        %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #shared1>
-        %129 = triton_gpu.convert_layout %121 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+        %130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0>
-        %130 = triton_gpu.convert_layout %58 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+        %131 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-        %131 = tt.dot %129, %130, %128 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+        %132 = triton_gpu.convert_layout %93 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-        %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
+        %133 = tt.dot %131, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0>
-        tt.store %arg29, %132 : tensor<128x64xf32, #blocked1>
+        %134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
-        %133 = tt.addptr %arg29, %41 : tensor<128x64x!tt.ptr<f32>, #blocked1>, tensor<128x64xi32, #blocked1>
+        %135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma0>
-        %134 = tt.addptr %arg30, %40 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+        %136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #blocked3>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-        %135 = tt.addptr %arg31, %40 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+        %81 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-        scf.yield %115, %126, %133, %134, %135 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64x!tt.ptr<f16>, #blocked0>
+        %137 = tt.dot %136, %81, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0>
+        %138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked2>
+        tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
+        %139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
+        %140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+        %141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+        scf.yield %116, %133, %139, %140, %141 : tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
       }
-      %82 = triton_gpu.convert_layout %81#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked0>
+      %83 = triton_gpu.convert_layout %82#1 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1>
-      %83 = triton_gpu.convert_layout %81#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked0>
+      %84 = triton_gpu.convert_layout %82#0 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1>
-      %84 = tt.addptr %42, %61 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %85 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %85 = arith.truncf %83 : tensor<128x64xf32, #blocked0> to tensor<128x64xf16, #blocked0>
+      %86 = arith.truncf %84 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
-      tt.store %84, %85 : tensor<128x64xf16, #blocked0>
+      tt.store %85, %86 : tensor<128x64xf16, #blocked1>
-      %86 = tt.addptr %43, %56 : tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<128x64xi32, #blocked0>
+      %87 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %87 = arith.truncf %82 : tensor<128x64xf32, #blocked0> to tensor<128x64xf16, #blocked0>
+      %88 = arith.truncf %83 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
-      tt.store %86, %87 : tensor<128x64xf16, #blocked0>
+      tt.store %87, %88 : tensor<128x64xf16, #blocked1>
     }
     return
   }
@@ -164,16 +164,14 @@ def _bwd_kernel(
             # NOTE: `do` is pre-divided by `l`; no normalization here
             qk = tl.dot(q, tl.trans(k))
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
-            # m = tl.load(m_ptrs + offs_m_curr)
+            m = tl.load(m_ptrs + offs_m_curr)
-            # p = tl.exp(qk * sm_scale - m[:, None])
+            p = tl.exp(qk * sm_scale - m[:, None])
-            p = tl.exp(qk * sm_scale)
             # compute dv
             do = tl.load(do_ptrs)
             dv += tl.dot(tl.trans(p.to(tl.float16)), do)
             # compute dp = dot(v, do)
-            # Di = tl.load(D_ptrs + offs_m_curr)
+            Di = tl.load(D_ptrs + offs_m_curr)
-            # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
+            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             dp += tl.dot(do, tl.trans(v))
             # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
@@ -287,7 +285,7 @@ class _attention(torch.autograd.Function):
         # num_stages=1,
         # )
         # print(pgm.asm["ttgir"])
-        # exit(1)
+        # # exit(1)
         return dq, dk, dv, None
 
 
@@ -326,8 +324,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
     triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)
 
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4