more optimizations
@@ -27,7 +27,32 @@ class TritonGPUDecomposeConversionsToDotOperandPass
 public:
   TritonGPUDecomposeConversionsToDotOperandPass() = default;
 
-  void runOnOperation() override { return; }
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp mod = getOperation();
+    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
+      OpBuilder builder(cvtOp);
+      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
+      auto dstType = cvtOp.getType().cast<RankedTensorType>();
+      auto srcBlocked =
+          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto dstDotOp =
+          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if (srcBlocked && dstDotOp) {
+        auto tmpType = RankedTensorType::get(
+            dstType.getShape(), dstType.getElementType(),
+            triton::gpu::SharedEncodingAttr::get(
+                mod.getContext(), dstDotOp, srcType.getShape(),
+                getOrder(srcBlocked), srcType.getElementType()));
+        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
+        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), dstType, tmp);
+        cvtOp.replaceAllUsesWith(newConvert.getResult());
+        cvtOp.erase();
+      }
+    });
+  }
 };
 
 std::unique_ptr<Pass>
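In IR terms, the pass above splits a direct blocked-to-dot-operand layout conversion into two conversions that go through shared memory. A minimal before/after sketch, reusing the #blocked1, #shared0 and #mma1 layouts defined in the python/unoptimized.ttgir listing below; the value names %x, %tmp and %opA are made up, and the real shared encoding is whatever SharedEncodingAttr::get computes, shown here as #shared0 only for readability:

// before: one convert_layout straight into a dot operand
%opA = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>

// after -tritongpu-decompose-conversions-to-dot-operand: blocked -> shared -> dot operand
%tmp = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%opA = triton_gpu.convert_layout %tmp : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>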
@@ -21,6 +21,17 @@
 
 using namespace mlir;
 
+static inline bool willIncreaseRegisterPressure(triton::gpu::ConvertLayoutOp op) {
+  auto srcType = op.getOperand().getType().cast<RankedTensorType>();
+  auto dstType = op.getResult().getType().cast<RankedTensorType>();
+  auto srcEncoding = srcType.getEncoding();
+  auto dstEncoding = dstType.getEncoding();
+  if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
+    return true;
+  if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
+    return true;
+  return false;
+}
+
 class TritonGPUSinkConversionsFromSharedPass
     : public TritonGPUSinkConversionsFromSharedBase<TritonGPUSinkConversionsFromSharedPass> {
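As a rough illustration of what this heuristic flags (made-up value names, layouts taken from the TTGIR listing below): conversions that read from a shared encoding or produce a dot-operand encoding are treated as register-pressure-increasing, anything else is left alone.

// flagged: source encoding is shared
%a = triton_gpu.convert_layout %s : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #blocked1>
// flagged: destination encoding is a dot operand
%b = triton_gpu.convert_layout %t : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
// not flagged: neither a shared source nor a dot-operand destination
%c = triton_gpu.convert_layout %u : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>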
@@ -28,6 +39,31 @@ public:
   TritonGPUSinkConversionsFromSharedPass() = default;
 
   void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+    // Move convert(load) immediately after dependent load
+    m.walk([&](triton::gpu::ConvertLayoutOp op) {
+      auto load = dyn_cast<triton::LoadOp>(op.getOperand().getDefiningOp());
+      if (!load)
+        return;
+      op->moveAfter(load);
+    });
+    // Sink conversions into loops when they will increase
+    // register pressure
+    DenseMap<triton::gpu::ConvertLayoutOp, Operation *> opToMove;
+    m.walk([&](triton::gpu::ConvertLayoutOp op) {
+      if (!willIncreaseRegisterPressure(op))
+        return;
+      auto user_begin = op->user_begin();
+      auto user_end = op->user_end();
+      if (std::distance(user_begin, user_end) != 1)
+        return;
+      opToMove.insert({op, *user_begin});
+    });
+    for (auto &kv : opToMove)
+      kv.first->moveBefore(kv.second);
+
+
     return;
   }
 };
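Schematically, the pass performs two motions; the snippet below is illustrative only (made-up value names, layouts from the TTGIR listing further down), not output from a real run. Conversions fed directly by a tt.load are first pulled up to sit right after that load; then any single-use conversion flagged by willIncreaseRegisterPressure is sunk down to sit immediately before its only user, shrinking the live range of the more register-hungry layout:

// before: the dot-operand value stays live across the ops in between
%a = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
// ... unrelated ops ...
%d = tt.dot %w, %a, %acc {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>

// after -tritongpu-sink-conversions-from-shared: the conversion sits next to its single user
// ... unrelated ops ...
%a = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%d = tt.dot %w, %a, %acc {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>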
@@ -191,6 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
+_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
@@ -258,34 +259,34 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        # _bwd_kernel[(ctx.grid[1],1,1)](
-        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
-        #     o.data_ptr(), do_scaled.data_ptr(),
-        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
-        #     l.data_ptr(), m.data_ptr(),
-        #     delta.data_ptr(),
-        #     q.stride(0), q.stride(1), q.stride(2),
-        #     k.stride(0), k.stride(1), k.stride(2),
-        #     v.stride(0), v.stride(1), v.stride(2),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     ctx.grid[0]
-        # )
-
-        pgm = _bwd_kernel[(ctx.grid[1],)](
-            q, k, v, ctx.sm_scale,
-            o, do_scaled,
-            dq, dk, dv,
-            l, m,
-            delta,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            ctx.grid[0],
-            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
-            num_stages=1,
-        )
+        _bwd_kernel[(ctx.grid[1],1,1)](
+            q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
+            o.data_ptr(), do_scaled.data_ptr(),
+            dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
+            l.data_ptr(), m.data_ptr(),
+            delta.data_ptr(),
+            q.stride(0), q.stride(1), q.stride(2),
+            k.stride(0), k.stride(1), k.stride(2),
+            v.stride(0), v.stride(1), v.stride(2),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0]
+        )
+
+        # pgm = _bwd_kernel[(ctx.grid[1],)](
+        #     q, k, v, ctx.sm_scale,
+        #     o, do_scaled,
+        #     dq, dk, dv,
+        #     l, m,
+        #     delta,
+        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+        #     q.shape[0], q.shape[1], q.shape[2],
+        #     ctx.grid[0],
+        #     BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+        #     BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+        #     num_stages=1,
+        # )
         # print(pgm.asm["ttgir"])
         # exit()
         return dq, dk, dv, None
@@ -380,4 +381,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     return ms
 
 
-# bench_flash_attention.run(save_path='.', print_data=True)
+bench_flash_attention.run(save_path='.', print_data=True)
python/unoptimized.ttgir (new file, 172 lines)
@@ -0,0 +1,172 @@
// TODO: swizzle
// TODO: move opIdx = 0 before opIdx = 1
// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
// TODO: don't convert loaded value to mma for accumulation
// triton-opt unoptimized.ttgir -tritongpu-sink-conversions-from-shared -tritongpu-decompose-conversions-to-dot-operand -cse
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128_i32 = arith.constant 128 : i32
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.divsi %0, %arg22 : i32
%2 = arith.remsi %0, %arg22 : i32
%3 = arith.muli %1, %arg12 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.addi %3, %4 : i32
%6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
%7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
%8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
%9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
%10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
%11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
%13 = arith.index_cast %arg24 : i32 to index
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
%21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
%29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
%33 = arith.muli %0, %arg23 : i32
%34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
%35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
%36 = arith.muli %arg24, %c128_i32 : i32
%37 = arith.index_cast %36 : i32 to index
%38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
%40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%41 = arith.muli %arg14, %c128_i32 : i32
%42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
%43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
%44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
scf.for %arg25 = %c0 to %13 step %c1 {
%46 = arith.index_cast %arg25 : index to i32
%47 = arith.muli %46, %c128_i32 : i32
%48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
%55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1>
%63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%65 = arith.index_cast %47 : i32 to index
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%80:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%87 = arith.index_cast %arg26 : index to i32
%88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0>
%89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%90 = arith.addi %88, %14 : tensor<128xi32, #blocked0>
%91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%115 = tt.addptr %40, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%120 = arith.subf %cst_1, %119 : tensor<128x128xf32, #mma0>
%121 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%123 = tt.dot %121, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%124 = arith.mulf %107, %123 : tensor<128x128xf32, #mma0>
%125 = arith.mulf %124, %39 : tensor<128x128xf32, #mma0>
%126 = arith.truncf %125 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%127 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%128 = tt.trans %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%129 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%130 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%131 = tt.dot %130, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%133 = triton_gpu.convert_layout %132 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%134 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%135 = tt.dot %134, %79, %133 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>
%137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %114, %131, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %82, %83 : tensor<128x64xf16, #blocked1>
%84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %85, %86 : tensor<128x64xf16, #blocked1>
}
return
}
}