From 3fefcd78d4d0ffe27f9cf5dabddba9685c46499e Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 9 Jan 2023 16:29:45 -0800 Subject: [PATCH] . --- python/tutorials/06-fused-attention.py | 54 +++++++++++++------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 8892b5529..3f8a63eb4 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -191,7 +191,7 @@ def _bwd_kernel( tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) -# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8) +_bwd_kernel = triton.compile("./slow.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8) # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432) @@ -260,34 +260,34 @@ class _attention(torch.autograd.Function): BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - # _bwd_kernel[(ctx.grid[1],1,1)]( - # q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, - # o.data_ptr(), do_scaled.data_ptr(), - # dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), - # l.data_ptr(), m.data_ptr(), - # delta.data_ptr(), - # q.stride(0), q.stride(1), q.stride(2), - # k.stride(0), k.stride(1), k.stride(2), - # v.stride(0), v.stride(1), v.stride(2), - # q.shape[0], q.shape[1], q.shape[2], - # ctx.grid[0] - # ) - - pgm = _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), + _bwd_kernel[(ctx.grid[1],1,1)]( + q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, + o.data_ptr(), do_scaled.data_ptr(), + dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), + l.data_ptr(), m.data_ptr(), + delta.data_ptr(), + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, + ctx.grid[0] ) + + # pgm = _bwd_kernel[(ctx.grid[1],)]( + # q, k, v, ctx.sm_scale, + # o, do_scaled, + # dq, dk, dv, + # l, m, + # delta, + # q.stride(0), q.stride(1), q.stride(2), q.stride(3), + # k.stride(0), k.stride(1), k.stride(2), k.stride(3), + # v.stride(0), v.stride(1), v.stride(2), v.stride(3), + # q.shape[0], q.shape[1], q.shape[2], + # ctx.grid[0], + # BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + # BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + # num_stages=1, + # ) # print(pgm.asm["ttgir"]) # exit() return dq, dk, dv, None