From 137e866bd2999eec9db650ac5bc6de14e3e8ce80 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Mon, 9 Jan 2023 16:20:10 -0800
Subject: [PATCH] more work

---
 .../Transforms/SinkConversionsFromShared.cpp | 11 +++-
 python/being-optimized.ttgir                 |  5 --
 python/tutorials/06-fused-attention.py       | 54 +++++++++----------
 3 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp b/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp
index 62bcae01a..34c34daa3 100644
--- a/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp
@@ -50,7 +50,7 @@ public:
     });
     // Sink conversions into loops when they will increase
     // register pressure
-    DenseMap opToMove;
+    DenseMap<Operation *, Operation *> opToMove;
     m.walk([&](triton::gpu::ConvertLayoutOp op){
       if(!willIncreaseRegisterPressure(op))
         return;
@@ -62,6 +62,15 @@
     });
     for(auto &kv: opToMove)
       kv.first->moveBefore(kv.second);
+
+    // Move transpositions just before their first use
+    opToMove.clear();
+    m.walk([&](triton::TransOp op){
+      auto user_begin = op->user_begin();
+      opToMove.insert({op, *user_begin});
+    });
+    for(auto &kv: opToMove)
+      kv.first->moveBefore(kv.second);
     return;
diff --git a/python/being-optimized.ttgir b/python/being-optimized.ttgir
index 270fdcf5a..038ca7d5e 100644
--- a/python/being-optimized.ttgir
+++ b/python/being-optimized.ttgir
@@ -1,8 +1,3 @@
-// TODO: swizzle
-// TODO: move opIdx = 0 before opIdx = 1
-// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-// don't convert loaded value to mma for accumulation
-
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 5a9a1d72f..8892b5529 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -191,7 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
-_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
+# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
@@ -260,34 +260,34 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        _bwd_kernel[(ctx.grid[1],1,1)](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
-            o.data_ptr(), do_scaled.data_ptr(),
-            dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
-            l.data_ptr(), m.data_ptr(),
-            delta.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            q.shape[0], q.shape[1], q.shape[2],
-            ctx.grid[0]
-        )
-
-        # pgm = _bwd_kernel[(ctx.grid[1],)](
-        #     q, k, v, ctx.sm_scale,
-        #     o, do_scaled,
-        #     dq, dk, dv,
-        #     l, m,
-        #     delta,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     ctx.grid[0],
-        #     BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-        #     BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
-        #     num_stages=1,
+        # _bwd_kernel[(ctx.grid[1],1,1)](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
+        #     o.data_ptr(), do_scaled.data_ptr(),
+        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
+        #     l.data_ptr(), m.data_ptr(),
+        #     delta.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
+        #     q.shape[0], q.shape[1], q.shape[2],
+        #     ctx.grid[0]
+        # )
+
+        pgm = _bwd_kernel[(ctx.grid[1],)](
+            q, k, v, ctx.sm_scale,
+            o, do_scaled,
+            dq, dk, dv,
+            l, m,
+            delta,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0],
+            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            num_stages=1,
+        )
         # print(pgm.asm["ttgir"])
         # exit()
         return dq, dk, dv, None
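---
Note on the SinkConversionsFromShared.cpp hunks: the new walk reuses opToMove to hoist
every triton::TransOp directly in front of its first user, presumably to shorten the live
range of the transposed value, in the same spirit as the existing register-pressure sink.
A minimal standalone sketch of that rewrite follows; the helper name
sinkTranspositionsToFirstUse, the include paths, and the empty-use guard are illustrative
assumptions, not part of the patch:

    // Sketch: move each transposition right before its first user,
    // mirroring the walk added in SinkConversionsFromShared.cpp.
    #include "mlir/IR/BuiltinOps.h"                // assumed header for ModuleOp
    #include "triton/Dialect/Triton/IR/Dialect.h"  // assumed header for triton::TransOp
    #include "llvm/ADT/DenseMap.h"

    using namespace mlir;

    static void sinkTranspositionsToFirstUse(ModuleOp m) {
      // Collect first, then move, following the pattern the pass already uses.
      llvm::DenseMap<Operation *, Operation *> opToMove;
      m.walk([&](triton::TransOp op) {
        // Guard added for illustration; the patch assumes at least one user.
        if (op->use_empty())
          return;
        // user_begin() follows use-list order, exactly as in the patch.
        opToMove.insert({op.getOperation(), *op->user_begin()});
      });
      for (auto &kv : opToMove)
        kv.first->moveBefore(kv.second);
    }

Keying the map on raw Operation pointers is what lets the same container hold both the
layout conversions handled earlier in the pass and the transpositions added here.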