This commit is contained in:
Philippe Tillet
2023-01-09 22:11:00 -08:00
parent d88353a5a4
commit ff04a5e9b6
4 changed files with 88 additions and 35 deletions

View File

@@ -148,7 +148,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %133 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>

View File

@@ -191,6 +191,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
# _bwd_kernel[(ctx.grid[1],1,1)](
# _bwd_kernel[(ctx.grid[1], 1, 1)](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
# o.data_ptr(), do_scaled.data_ptr(),
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),