.
y
This commit is contained in:
@@ -44,13 +44,13 @@ def _fwd_kernel(
|
|||||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||||
# load q: it will stay in SRAM throughout
|
# load q: it will stay in SRAM throughout
|
||||||
q = tl.load(q_ptrs)
|
q = tl.load(q_ptrs)
|
||||||
|
q *= (q.to(tl.float32) * sm_scale).to(tl.float16)
|
||||||
# loop over k, v and update accumulator
|
# loop over k, v and update accumulator
|
||||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||||
# -- compute qk ----
|
# -- compute qk ----
|
||||||
k = tl.load(k_ptrs)
|
k = tl.load(k_ptrs)
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
qk *= sm_scale
|
|
||||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||||
# compute new m
|
# compute new m
|
||||||
m = tl.maximum(tl.max(qk, 1), m_prev)
|
m = tl.maximum(tl.max(qk, 1), m_prev)
|
||||||
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
|
|||||||
ylabel='ms',
|
ylabel='ms',
|
||||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||||
) for mode in ['fwd']]
|
) for mode in ['bwd']]
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
@triton.testing.perf_report(configs)
|
||||||
|
Reference in New Issue
Block a user