more work
This commit is contained in:
@@ -50,7 +50,7 @@ public:
|
|||||||
});
|
});
|
||||||
// Sink conversions into loops when they will increase
|
// Sink conversions into loops when they will increase
|
||||||
// register pressure
|
// register pressure
|
||||||
DenseMap<triton::gpu::ConvertLayoutOp, Operation *> opToMove;
|
DenseMap<Operation*, Operation *> opToMove;
|
||||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||||
if(!willIncreaseRegisterPressure(op))
|
if(!willIncreaseRegisterPressure(op))
|
||||||
return;
|
return;
|
||||||
@@ -63,6 +63,15 @@ public:
|
|||||||
for(auto &kv: opToMove)
|
for(auto &kv: opToMove)
|
||||||
kv.first->moveBefore(kv.second);
|
kv.first->moveBefore(kv.second);
|
||||||
|
|
||||||
|
// Move transpositions just before their first use
|
||||||
|
opToMove.clear();
|
||||||
|
m.walk([&](triton::TransOp op){
|
||||||
|
auto user_begin = op->user_begin();
|
||||||
|
opToMove.insert({op, *user_begin});
|
||||||
|
});
|
||||||
|
for(auto &kv: opToMove)
|
||||||
|
kv.first->moveBefore(kv.second);
|
||||||
|
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@@ -1,8 +1,3 @@
|
|||||||
// TODO: swizzle
|
|
||||||
// TODO: move opIdx = 0 before opIdx = 1
|
|
||||||
// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
|
||||||
// don't convert loaded value to mma for accumulation
|
|
||||||
|
|
||||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||||
|
@@ -191,7 +191,7 @@ def _bwd_kernel(
|
|||||||
tl.store(dv_ptrs, dv)
|
tl.store(dv_ptrs, dv)
|
||||||
tl.store(dk_ptrs, dk)
|
tl.store(dk_ptrs, dk)
|
||||||
|
|
||||||
_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
|
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
|
||||||
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
|
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
|
||||||
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
|
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
|
||||||
# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
|
# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
|
||||||
@@ -260,34 +260,34 @@ class _attention(torch.autograd.Function):
|
|||||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||||
)
|
)
|
||||||
|
|
||||||
_bwd_kernel[(ctx.grid[1],1,1)](
|
# _bwd_kernel[(ctx.grid[1],1,1)](
|
||||||
q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
|
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
|
||||||
o.data_ptr(), do_scaled.data_ptr(),
|
# o.data_ptr(), do_scaled.data_ptr(),
|
||||||
dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
|
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
|
||||||
l.data_ptr(), m.data_ptr(),
|
# l.data_ptr(), m.data_ptr(),
|
||||||
delta.data_ptr(),
|
# delta.data_ptr(),
|
||||||
q.stride(0), q.stride(1), q.stride(2),
|
# q.stride(0), q.stride(1), q.stride(2),
|
||||||
k.stride(0), k.stride(1), k.stride(2),
|
# k.stride(0), k.stride(1), k.stride(2),
|
||||||
v.stride(0), v.stride(1), v.stride(2),
|
# v.stride(0), v.stride(1), v.stride(2),
|
||||||
q.shape[0], q.shape[1], q.shape[2],
|
|
||||||
ctx.grid[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# pgm = _bwd_kernel[(ctx.grid[1],)](
|
|
||||||
# q, k, v, ctx.sm_scale,
|
|
||||||
# o, do_scaled,
|
|
||||||
# dq, dk, dv,
|
|
||||||
# l, m,
|
|
||||||
# delta,
|
|
||||||
# q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
|
||||||
# k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
|
||||||
# v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
|
||||||
# q.shape[0], q.shape[1], q.shape[2],
|
# q.shape[0], q.shape[1], q.shape[2],
|
||||||
# ctx.grid[0],
|
# ctx.grid[0]
|
||||||
# BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
|
||||||
# BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
|
||||||
# num_stages=1,
|
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
pgm = _bwd_kernel[(ctx.grid[1],)](
|
||||||
|
q, k, v, ctx.sm_scale,
|
||||||
|
o, do_scaled,
|
||||||
|
dq, dk, dv,
|
||||||
|
l, m,
|
||||||
|
delta,
|
||||||
|
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||||
|
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||||
|
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||||
|
q.shape[0], q.shape[1], q.shape[2],
|
||||||
|
ctx.grid[0],
|
||||||
|
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||||
|
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||||
|
num_stages=1,
|
||||||
|
)
|
||||||
# print(pgm.asm["ttgir"])
|
# print(pgm.asm["ttgir"])
|
||||||
# exit()
|
# exit()
|
||||||
return dq, dk, dv, None
|
return dq, dk, dv, None
|
||||||
|
Reference in New Issue
Block a user