y
This commit is contained in:
Philippe Tillet
2023-01-09 22:58:40 -08:00
parent 2fa0dfbce9
commit b162c44d59

View File

@@ -44,13 +44,13 @@ def _fwd_kernel(
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
q *= (q.to(tl.float32) * sm_scale).to(tl.float16)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m = tl.maximum(tl.max(qk, 1), m_prev)
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['fwd']]
) for mode in ['bwd']]
@triton.testing.perf_report(configs)