[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend: - Now using inline PTX for `tl.store`, making it possible to use things like evict_last - For A100, mma layout can be directly converted to shared memory - For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major. - Fixed liveness analysis; this was broken. - Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop. - `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
@@ -840,10 +840,10 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
[(epilogue, allow_tf32, dtype)
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float32', 'int8']
|
||||
if not (allow_tf32 and (dtype == 'int8'))])
|
||||
for dtype in ['float16']
|
||||
if not (allow_tf32 and (dtype in ['float16']))])
|
||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80:
|
||||
@@ -852,21 +852,30 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
|
||||
M, N, K = 128, 128, 64
|
||||
num_warps = 8
|
||||
trans_a, trans_b = False, False
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xk,
|
||||
Y, stride_yk, stride_yn,
|
||||
W, stride_wn, stride_wl,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr):
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_l = tl.arange(0, BLOCK_N)
|
||||
off_k = tl.arange(0, BLOCK_K)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
@@ -875,39 +884,65 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
if ADD_COLS:
|
||||
ZCs = Z + off_n * stride_zn
|
||||
z += tl.load(ZCs)[None, :]
|
||||
if DO_SOFTMAX:
|
||||
max = tl.max(z, 1)
|
||||
z = z - max[:, None]
|
||||
num = tl.exp(z)
|
||||
den = tl.sum(num, 1)
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
|
||||
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
|
||||
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
|
||||
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
|
||||
if allow_tf32:
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
# triton result
|
||||
z = numpy_random((M, N), dtype_str=dtype, rs=rs)
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
TRANS_A=trans_a, TRANS_B=trans_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
ALLOW_TF32=allow_tf32)
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# torch result
|
||||
z_ref = np.matmul(x, y)
|
||||
x_ref = x.T if trans_a else x
|
||||
y_ref = y.T if trans_b else y
|
||||
z_ref = np.matmul(x_ref, y_ref)
|
||||
if epilogue == 'add-matrix':
|
||||
z_ref += z
|
||||
if epilogue == 'add-rows':
|
||||
z_ref += z[:, 0][:, None]
|
||||
if epilogue == 'add-cols':
|
||||
z_ref += z[0, :][None, :]
|
||||
if epilogue == 'softmax':
|
||||
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
|
||||
denom = np.sum(num, axis=-1, keepdims=True)
|
||||
z_ref = num / denom
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
|
@@ -211,7 +211,7 @@ class ValueConstructor:
|
||||
return phi
|
||||
v = unique_handles.pop()
|
||||
phi.handle.replace_all_uses_with(v)
|
||||
phi.handle.erase_from_parent()
|
||||
# phi.handle.erase_from_parent()
|
||||
# TODO: remove trivial phis recursively
|
||||
return triton.language.tensor(v, phi.type)
|
||||
|
||||
|
@@ -732,7 +732,7 @@ def reshape(input, shape, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, allow_tf32=True, _builder=None):
|
||||
def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
@@ -744,7 +744,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
|
||||
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||
"""
|
||||
allow_tf32 = _constexpr_to_value(allow_tf32)
|
||||
return semantic.dot(input, other, allow_tf32, _builder)
|
||||
return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
|
||||
|
||||
|
||||
@builtin
|
||||
def store(pointer, value, mask=None, _builder=None):
|
||||
def store(pointer, value, eviction_policy="", mask=None, _builder=None):
|
||||
"""
|
||||
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||
|
||||
@@ -799,7 +799,7 @@ def store(pointer, value, mask=None, _builder=None):
|
||||
value = _to_tensor(value, _builder)
|
||||
if mask is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
return semantic.store(pointer, value, mask, _builder)
|
||||
return semantic.store(pointer, value, mask, eviction_policy, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
|
@@ -648,6 +648,18 @@ def cast(input: tl.tensor,
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def _parse_eviction_policy(eviction_policy):
|
||||
eviction = ir.EVICTION_POLICY.NORMAL # default
|
||||
if eviction_policy:
|
||||
if eviction_policy == "evict_last":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_LAST
|
||||
elif eviction_policy == "evict_first":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_FIRST
|
||||
else:
|
||||
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
||||
return eviction
|
||||
|
||||
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
@@ -684,14 +696,7 @@ def load(ptr: tl.tensor,
|
||||
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
||||
|
||||
# eviction policy
|
||||
eviction = ir.EVICTION_POLICY.NORMAL # default
|
||||
if eviction_policy:
|
||||
if eviction_policy == "evict_last":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_LAST
|
||||
elif eviction_policy == "evict_first":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_FIRST
|
||||
else:
|
||||
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
||||
eviction = _parse_eviction_policy(eviction_policy)
|
||||
|
||||
if ptr.type.is_block():
|
||||
shape = ptr.type.get_block_shapes()
|
||||
@@ -721,6 +726,7 @@ def load(ptr: tl.tensor,
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
@@ -735,14 +741,15 @@ def store(ptr: tl.tensor,
|
||||
elt_ty_ptr = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
|
||||
# eviction policy
|
||||
eviction = _parse_eviction_policy(eviction_policy)
|
||||
# cast to target data-type
|
||||
val = cast(val, elt_ty, builder)
|
||||
if not mask:
|
||||
return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
|
||||
return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
|
||||
if not mask.type.scalar.is_bool():
|
||||
raise ValueError("Mask must have boolean scalar type")
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
|
||||
|
||||
#########
|
||||
# atomic
|
||||
@@ -897,27 +904,31 @@ def atomic_xchg(ptr: tl.tensor,
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
def dot(a: tl.tensor,
|
||||
b: tl.tensor,
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
allow_tf32: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[-1] == rhs.shape[0]
|
||||
assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
|
||||
in_a = 1 if not trans_a else 0
|
||||
in_b = 1 if trans_b else 0
|
||||
assert a.type.is_block() and b.type.is_block()
|
||||
assert len(a.shape) == 2 and len(b.shape) == 2
|
||||
assert a.shape[in_a] == b.shape[in_b]
|
||||
assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
|
||||
"small blocks not supported!"
|
||||
if lhs.type.scalar.is_int():
|
||||
if a.type.scalar.is_int():
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
else:
|
||||
_0 = builder.get_float32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
M = lhs.type.shape[0]
|
||||
N = rhs.type.shape[1]
|
||||
M = a.type.shape[in_a ^ 1]
|
||||
N = b.type.shape[in_b ^ 1]
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
|
||||
return tl.tensor(ret, ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
198
python/tutorials/06-fused-attention.py
Normal file
198
python/tutorials/06-fused-attention.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kk, stride_kn,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_qm = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
|
||||
q = tl.load(q_ptrs)
|
||||
for start_n in range(0, start_qm + 1):
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs)
|
||||
qk = tl.dot(q, k)
|
||||
qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
p = p.to(tl.float16)
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs)
|
||||
acc += tl.dot(p, v)
|
||||
k_ptrs += BLOCK_N * stride_kn
|
||||
v_ptrs += BLOCK_N * stride_vk
|
||||
# r_ptrs += BLOCK_N
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
start_qm = tl.program_id(0)
|
||||
offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
out_ptrs = Out + off_out
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-2]
|
||||
assert Lq == Lk
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v,
|
||||
tmp, L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=64, num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
return o
|
||||
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)])
|
||||
def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
ref_qk = torch.matmul(q, k)
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
ref_qk[:, :, M == 0] = float("-inf")
|
||||
ref_qk = torch.softmax(ref_qk, dim=-1)
|
||||
ref_out = torch.matmul(ref_qk, v)
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64
|
||||
# vary batch size for fixed heads / seq
|
||||
batch_bench = triton.testing.Benchmark(
|
||||
x_names=['BATCH'],
|
||||
x_vals=[2**i for i in range(0, 8)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}',
|
||||
args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
|
||||
)
|
||||
# vary seq length for fixed head and batch=4
|
||||
seq_bench = triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}',
|
||||
args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report([batch_bench, seq_bench])
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"):
|
||||
warmup = 25
|
||||
rep = 500
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
|
||||
fn = lambda: attention(q, k, v)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
|
||||
bench_flash_attention.run(save_path='.', print_data=True)
|
Reference in New Issue
Block a user