@@ -148,7 +148,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %133 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
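For reference, the TTGIR in this hunk is the lowered form of a `tl.dot` followed by a `tl.store`: the f16 operands are first converted into `#triton_gpu.dot_op` layouts (`opIdx = 0` / `opIdx = 1`), the `tt.dot` result is produced in the `#mma1` layout, and a final `triton_gpu.convert_layout` moves it back to `#blocked2` so it can be stored. A minimal Triton-Python sketch of the kind of kernel that lowers to this pattern; the pointer names and block sizes below are hypothetical and not taken from this file:

```python
import triton
import triton.language as tl

@triton.jit
def _dot_store_sketch(a_ptr, b_ptr, out_ptr,
                      BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)
    # the loaded tiles become the opIdx=0 / opIdx=1 dot_op operands in the TTGIR
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    # the dot accumulates in f32 and is materialized in an #mma layout
    acc = tl.dot(a, b)
    # the store forces a convert_layout from #mma back to a #blocked layout
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```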
@@ -191,6 +191,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
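The commented-out lines above are the hook for replacing the `@triton.jit` `_bwd_kernel` defined earlier in the file with a kernel precompiled from an on-disk TTGIR file, presumably to compare the `slow`, `unoptimized`, and `bwd` variants. A hedged sketch, assuming the `triton.compile(path, num_warps=...)` interface these comments rely on; only one variant would be enabled at a time:

```python
import triton

# Swap the JIT kernel for one of the precompiled TTGIR variants named in the diff;
# which file to pass depends on which variant is being measured.
_bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
```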
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
# _bwd_kernel[(ctx.grid[1],1,1)](
# _bwd_kernel[(ctx.grid[1], 1, 1)](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
# o.data_ptr(), do_scaled.data_ptr(),
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
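The commented-out call above is the launch path that goes with the `triton.compile` lines in the previous hunk: a kernel compiled directly from TTGIR is indexed with an explicit grid tuple and receives raw `data_ptr()` values and scalars instead of torch tensors. A sketch of that disabled path, kept commented out as in the diff; the argument list stops where the diff stops, since the remaining arguments are not shown in this commit:

```python
# Launch of the precompiled kernel inside _attention.backward (disabled in the diff):
# _bwd_kernel[(ctx.grid[1], 1, 1)](
#     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
#     o.data_ptr(), do_scaled.data_ptr(),
#     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
#     ...
# )
```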