[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code * at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks. * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros * blocksparse softmax now no longer modifies any data in-place * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention. * unit tests now test backward pass
This commit is contained in:
@@ -10,77 +10,108 @@ import triton
|
||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
|
||||
@pytest.mark.parametrize("DTYPE", [torch.float16])
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seed = 0
|
||||
torch.manual_seed(seed)
|
||||
is_sdd = MODE == "sdd"
|
||||
is_dsd = MODE == "dsd"
|
||||
is_dds = MODE == "dds"
|
||||
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
|
||||
# create inputs
|
||||
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
|
||||
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
|
||||
# create op
|
||||
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
|
||||
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
|
||||
c_shape = (Z, H, M, N)
|
||||
shape = {
|
||||
"sdd": (M, N),
|
||||
"dsd": (a.shape[2], a.shape[3]),
|
||||
"dds": (b.shape[2], b.shape[3]),
|
||||
"dsd": (a_shape[2], a_shape[3]),
|
||||
"dds": (b_shape[2], b_shape[3]),
|
||||
}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
b_ref = do_mask(b_ref) if is_dds else b_ref
|
||||
a_ref.retain_grad()
|
||||
b_ref.retain_grad()
|
||||
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
|
||||
b_ref.transpose(2, 3) if TRANS_B else b_ref)
|
||||
c_ref.backward(dc_ref)
|
||||
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
|
||||
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
|
||||
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
|
||||
# triton result
|
||||
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
|
||||
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
|
||||
b_tri = do_sparsify(b_tri) if is_dds else b_tri
|
||||
a_tri.retain_grad()
|
||||
b_tri.retain_grad()
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
|
||||
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
|
||||
# torch result
|
||||
ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
ta = ta.transpose(2, 3) if TRANS_A else ta
|
||||
tb = tb.transpose(2, 3) if TRANS_B else tb
|
||||
tc = torch.matmul(ta, tb)
|
||||
tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
|
||||
tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
|
||||
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
|
||||
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
|
||||
da_tri = a_tri.grad
|
||||
db_tri = b_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(rc, tc)
|
||||
triton.testing.assert_almost_equal(c_ref, c_tri)
|
||||
triton.testing.assert_almost_equal(da_ref, da_tri)
|
||||
triton.testing.assert_almost_equal(db_ref, db_tri)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
|
||||
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
|
||||
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
|
||||
def test_softmax(BLOCK, WIDTH, DTYPE):
|
||||
is_causal = True
|
||||
configs = [
|
||||
(16, 256),
|
||||
(32, 576),
|
||||
(64, 1871),
|
||||
(128, 2511),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_dense", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
|
||||
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
Z, H, M, N = 1, 1, WIDTH, WIDTH
|
||||
scale = 0.4
|
||||
# create inputs
|
||||
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
|
||||
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
|
||||
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
|
||||
Z, H, M, N = 2, 3, WIDTH, WIDTH
|
||||
# initialize layout
|
||||
# make sure each row has at least one non-zero element
|
||||
torch.diagonal(layout)[:] = 1
|
||||
torch.diagonal(at_mask)[:] = 1
|
||||
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
|
||||
kp_mask[:] = 0
|
||||
kp_mask[kp_mask == 1.0] = float("-inf")
|
||||
# triton result
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK)
|
||||
tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
ty = op(
|
||||
tx,
|
||||
scale=scale,
|
||||
key_padding_mask=kp_mask,
|
||||
key_padding_mask_mode="add",
|
||||
attn_mask=at_mask.to(DTYPE),
|
||||
attn_mask_mode="mul",
|
||||
is_causal=is_causal,
|
||||
)
|
||||
# torch result
|
||||
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
|
||||
# broadcast at_mask to the same shape as rx
|
||||
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
|
||||
if is_dense:
|
||||
layout[:] = 1
|
||||
else:
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# initialize data
|
||||
a_shape = (Z, H, M, N)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape)
|
||||
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
|
||||
# compute [torch]
|
||||
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref.retain_grad()
|
||||
at_mask = torch.ones((M, N), device="cuda")
|
||||
if is_causal:
|
||||
at_mask = torch.tril(at_mask)
|
||||
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
|
||||
rx[M == 0] = float("-inf")
|
||||
# rx += kp_mask[:, None, None, :]
|
||||
ry = torch.softmax(rx * scale, -1)
|
||||
ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
|
||||
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
|
||||
a_ref[M == 0] = float("-inf")
|
||||
out_ref = torch.softmax(a_ref * scale, -1)
|
||||
out_ref.backward(dout_ref)
|
||||
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
# compute [triton]
|
||||
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri.retain_grad()
|
||||
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
|
||||
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
|
||||
out_tri.backward(dout_tri)
|
||||
da_tri = a_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ry, ty)
|
||||
triton.testing.assert_almost_equal(out_tri, out_ref)
|
||||
triton.testing.assert_almost_equal(da_tri, da_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
@@ -99,14 +130,6 @@ def test_attention_fwd_bwd(
|
||||
qkvs = [
|
||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
|
||||
]
|
||||
attn_mask = torch.tril(
|
||||
torch.ones(
|
||||
[n_ctx, n_ctx],
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
diagonal=0,
|
||||
)
|
||||
|
||||
# Triton:
|
||||
n_blocks = n_ctx // block
|
||||
@@ -115,7 +138,7 @@ def test_attention_fwd_bwd(
|
||||
query.retain_grad()
|
||||
key.retain_grad()
|
||||
value.retain_grad()
|
||||
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
|
||||
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
|
||||
# ad hoc loss
|
||||
loss = (attn_out ** 2).mean()
|
||||
loss.backward()
|
||||
@@ -123,6 +146,8 @@ def test_attention_fwd_bwd(
|
||||
|
||||
# Torch version:
|
||||
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
|
||||
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
|
||||
attn_mask = torch.tril(attn_mask, diagonal=0)
|
||||
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
|
||||
torch_q.retain_grad()
|
||||
torch_k.retain_grad()
|
||||
@@ -147,7 +172,6 @@ def test_attention_fwd_bwd(
|
||||
def triton_attention(
|
||||
layout,
|
||||
block: int,
|
||||
attn_mask: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@@ -155,12 +179,9 @@ def triton_attention(
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(
|
||||
layout,
|
||||
block,
|
||||
)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
|
||||
|
||||
w = sparse_dot_sdd_nt(query, key)
|
||||
w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
|
||||
w = sparse_softmax(w, scale=scale, is_causal=True)
|
||||
a = sparse_dot_dsd_nn(w, value)
|
||||
return a
|
||||
|
@@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
|
||||
|
||||
def dsd_lut(layout, block, step, trans, device):
|
||||
"""
|
||||
Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
|
||||
Example (BLOCK=32, STEP=16)
|
||||
[[1, 0, 0, 1, 0],
|
||||
[0, 1, 1, 0, 1],
|
||||
[1, 0, 1, 0, 0]]
|
||||
|
||||
Then the offsets for A are
|
||||
[0 , 16, 32, 48] <- row 0
|
||||
\\----/ \\----/
|
||||
col=0 col=3
|
||||
[64, 80, 96, 112, 128, 144] <- row 1
|
||||
\\----/ \\----/ \\------/
|
||||
col=1 col=2 col=3
|
||||
[160, 176, 192, 208]
|
||||
which leads to increments table
|
||||
[0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
|
||||
|
||||
Because B is dense, the offsets are
|
||||
[0, 16, 96, 112] <- row 0
|
||||
[32, 48, 64, 80] <- row 1
|
||||
[0, 16, 64, 80] <- row 2
|
||||
"""
|
||||
sizes = torch.sum(layout, 2 if trans else 1)
|
||||
head_id, col_id = sizes.nonzero(as_tuple=True)
|
||||
head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
|
||||
sizes = sizes.flatten()
|
||||
segments = sizes * step
|
||||
# pointer increments
|
||||
@@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# -------------------------------
|
||||
# dense input pointer increments
|
||||
# -------------------------------
|
||||
# given a list of the indices for the first element of each non-zero block.
|
||||
# For example, for the indices
|
||||
# [32, 80, 128, 256, 288]
|
||||
# we would generate the increments
|
||||
# [32, 48, 48, 128, 32]
|
||||
# ^
|
||||
# index of first element
|
||||
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
|
||||
# that is smaller than the block size, so we need to do a bit of extra work
|
||||
# to handle this case
|
||||
@@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dds_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ka,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
num_pid_m = tl.num_programs(0)
|
||||
num_pid_n = tl.num_programs(1)
|
||||
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
|
||||
pid_z = tl.program_id(2)
|
||||
header = lut + pid_n * 4
|
||||
offset = tl.load(header + 0)
|
||||
AS1 = tl.load(header + 1)
|
||||
column = tl.load(header + 2)
|
||||
off_h = tl.load(header + 3)
|
||||
pinc = lut + offset
|
||||
# initialize pointers to A (dense)
|
||||
offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
|
||||
start_ak = tl.load(pinc)
|
||||
start_ak = tl.multiple_of(start_ak, 8)
|
||||
offs_ak = start_ak + tl.arange(0, TILE_K)
|
||||
ptrs_a = A + pid_z * stride_za \
|
||||
+ off_h * stride_ha \
|
||||
+ offs_am[:, None] * stride_ma \
|
||||
+ offs_ak[None, :] * stride_ka
|
||||
# initialize pointers to B (sparse)
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8)
|
||||
offs_bn = tl.arange(0, TILE_N)
|
||||
offs_bk = tl.arange(0, TILE_K)
|
||||
ptrs_b = B + pid_z * stride_zb \
|
||||
+ block_id * stride_hb \
|
||||
+ offs_bn[None, :] * stride_bn \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
# ---------------- #
|
||||
# Inner Loop #
|
||||
# ---------------- #
|
||||
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
|
||||
for k in range(AS1, 0, -TILE_K):
|
||||
a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
|
||||
b = tl.load(ptrs_b, mask=True)
|
||||
acc += tl.dot(a, b)
|
||||
pinc += 2
|
||||
inc_a = tl.load(pinc)
|
||||
inc_b = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = inc_a * stride_ka
|
||||
ptrs_a += inc_a
|
||||
ptrs_b += inc_b
|
||||
# ---------------- #
|
||||
# Epilogue #
|
||||
# ---------------- #
|
||||
c = acc.to(C.dtype.element_ty)
|
||||
# initialize pointers to C (dense)
|
||||
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
|
||||
ptrs_c = C + off_h * stride_hc \
|
||||
+ pid_z * stride_zc \
|
||||
+ offs_cm[:, None] * stride_mc \
|
||||
+ offs_cn[None, :] * stride_nc
|
||||
# write back
|
||||
tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
|
||||
# AB = (B^T A^T)^T
|
||||
|
||||
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# shapes / dtypes
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# output
|
||||
CS0 = AS0
|
||||
CS1 = AS1
|
||||
CS2 = BS2 if trans_c else AS2
|
||||
CS3 = AS2 if trans_c else BS2
|
||||
if out is None:
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
else:
|
||||
assert out.shape == (CS0, CS1, CS2, CS3)
|
||||
c = out
|
||||
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
|
||||
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
|
||||
_dds_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
AS2, BS2, lut,
|
||||
TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
)
|
||||
return c
|
||||
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
|
@@ -5,230 +5,235 @@ import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
|
||||
if n < 512:
|
||||
if n <= 128:
|
||||
return 1
|
||||
if n <= 256:
|
||||
return 2
|
||||
if n <= 512:
|
||||
return 4
|
||||
if n < 2048:
|
||||
if n <= 4096:
|
||||
return 8
|
||||
return 16
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.jit
|
||||
def _forward(
|
||||
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
|
||||
KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
pidhm = tl.program_id(0)
|
||||
pidz = tl.program_id(1)
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = tl.arange(0, TN) % BLOCK
|
||||
rbn = tl.arange(0, TN) // BLOCK
|
||||
hm = h * tl.num_programs(1) + m
|
||||
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
|
||||
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
|
||||
# extract information from LUT
|
||||
header = LUT + rbm * 2
|
||||
header = LUT + (hm // BLOCK_SIZE) * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
check = rbn < size
|
||||
rbmn = tl.where(check, rbn, size - 1)
|
||||
# block id and column id
|
||||
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
|
||||
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
|
||||
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
|
||||
headid = tl.load(LUT + offset + rbmn * 4 + 3)
|
||||
# pointers to X
|
||||
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
x = tl.load(px, mask=check, other=-float('inf'))
|
||||
x = x.to(tl.float32)
|
||||
# apply scale
|
||||
if APPLY_SCALE:
|
||||
x = x * scale
|
||||
# apply RPE
|
||||
if APPLY_RPE:
|
||||
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
|
||||
rpe = tl.load(prpe, mask=check, other=0)
|
||||
x = x + rpe
|
||||
# apply key-padding mask
|
||||
if APPLY_KP_MASK:
|
||||
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
|
||||
kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
|
||||
if KP_MASK_MUL:
|
||||
kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
|
||||
x = x + kp_m
|
||||
# apply attention mask
|
||||
if APPLY_ATTN_MASK:
|
||||
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
|
||||
attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
|
||||
if ATTN_MASK_MUL:
|
||||
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
|
||||
x = x + attn_m
|
||||
# pointer offset
|
||||
off_a = z * stride_xz
|
||||
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
|
||||
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
|
||||
# do not need to read column indices in the dense case
|
||||
if IS_DENSE:
|
||||
ns = tl.arange(0, ROW_SIZE)
|
||||
else:
|
||||
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
|
||||
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
|
||||
ns = start_n * BLOCK_SIZE + lane_n
|
||||
# load X
|
||||
mask = block_n < size
|
||||
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
|
||||
a = a.to(tl.float32)
|
||||
# compute
|
||||
out = a
|
||||
out *= scale
|
||||
# apply relative attention
|
||||
if R is not None:
|
||||
R += z * stride_zr
|
||||
R += h * stride_hr
|
||||
off_lo = (extent - m - 1) + ns
|
||||
mask_lo = (off_lo >= 0) & (off_lo < extent)
|
||||
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
|
||||
out += rel_logits
|
||||
out = out.to(tl.float32)
|
||||
# apply causal mask
|
||||
is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
|
||||
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
|
||||
out = tl.where((ns > m) & is_causal, -float("inf"), out)
|
||||
# computation
|
||||
x = tl.softmax(x)
|
||||
tl.store(px, x, mask=check)
|
||||
out = tl.softmax(out)
|
||||
# write-back
|
||||
tl.store(Out + off_a + lane_n, out, mask=mask)
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
|
||||
@triton.jit
|
||||
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
|
||||
pidhm = tl.program_id(0)
|
||||
pidz = tl.program_id(1)
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = tl.arange(0, TN) % BLOCK
|
||||
rbn = tl.arange(0, TN) // BLOCK
|
||||
# extract information from look-up table
|
||||
header = LUT + rbm * 2
|
||||
hm = h * tl.num_programs(1) + m
|
||||
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
|
||||
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
|
||||
# extract information from LUT
|
||||
header = LUT + (hm // BLOCK_SIZE) * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
# bounds checking on lut
|
||||
check = rbn < size
|
||||
rbmn = tl.where(check, rbn, size - 1)
|
||||
# initialize pointers to block-sparse input
|
||||
blockid = tl.load(LUT + offset + rbmn * 4)
|
||||
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
# compute fused softmax backward
|
||||
x = tl.load(X, mask=check, other=0)
|
||||
dx = tl.load(DX, mask=check, other=0)
|
||||
x = x.to(tl.float32)
|
||||
dx = dx.to(tl.float32)
|
||||
y = x * (dx - tl.sum(x * dx, 0)) * scale
|
||||
tl.store(DX, y, mask=check)
|
||||
# row-col offset
|
||||
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
|
||||
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
|
||||
mask = block_n < size
|
||||
# pointers
|
||||
As = Out + z * stride_zout + off_mn
|
||||
DOuts = DOut + z * stride_zdout + off_mn
|
||||
# do not need to read column indices in the dense case
|
||||
if IS_DENSE:
|
||||
ns = tl.arange(0, ROW_SIZE)
|
||||
else:
|
||||
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
|
||||
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
|
||||
ns = start_n * BLOCK_SIZE + lane_n
|
||||
# load data
|
||||
a = tl.load(As + lane_n, mask=mask, other=0.0)
|
||||
a = a.to(tl.float32)
|
||||
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
|
||||
dout = dout.to(tl.float32)
|
||||
# compute
|
||||
da = a * (dout - tl.sum(a * dout, 0))
|
||||
da = tl.where((ns > m) & is_causal, 0., da)
|
||||
# apply relative attention
|
||||
if DR is not None:
|
||||
DR += z * stride_zr
|
||||
DR += h * stride_hr
|
||||
off_lo = (extent - m - 1) + ns
|
||||
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
|
||||
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
|
||||
da = da * scale
|
||||
# convert da
|
||||
# write-back
|
||||
DAs = DA + z * stride_zdx + off_mn
|
||||
tl.store(DAs + lane_n, da, mask=mask)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
sizes = _empty.clone()
|
||||
# sizes along rows
|
||||
sizes = layout.sum(-1).view(-1)
|
||||
for h in range(layout.shape[0]):
|
||||
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
|
||||
total_sizes = sizes * block
|
||||
# offsets in block format
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
# block indices
|
||||
layout_sum = sizes.sum()
|
||||
idx = torch.arange(layout_sum, device=layout.device)
|
||||
layout_nonzero = layout.nonzero(as_tuple=False)
|
||||
head = layout_nonzero[:, 0]
|
||||
rows = layout_nonzero[:, 1]
|
||||
columns = layout_nonzero[:, 2]
|
||||
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
|
||||
# construct look-up table
|
||||
offsets = offsets * 4 + 2 * sizes.numel()
|
||||
columns = layout.nonzero(as_tuple=False)[:, 2]
|
||||
header = torch.stack((sizes, offsets), dim=1).view(-1)
|
||||
lut = torch.cat((header, core)).type(torch.int32).to(device)
|
||||
return lut, int(sizes.max())
|
||||
lut = torch.cat((header, columns)).type(torch.int32).to(device)
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
kp_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
spdims, block, lut, maxlut
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
apply_scale = False if scale == 1.0 else True
|
||||
# handle None rpe
|
||||
if rpe is None:
|
||||
apply_rpe = False
|
||||
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
|
||||
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_rpe = True
|
||||
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
|
||||
# handle None key_padding_mask
|
||||
if key_padding_mask is None:
|
||||
apply_kp_mask = False
|
||||
stride_zkpm = 0
|
||||
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_kp_mask = True
|
||||
stride_zkpm = key_padding_mask.stride(0)
|
||||
# handle None attention_mask
|
||||
if attn_mask is None:
|
||||
apply_attn_mask = False
|
||||
stride_zattnm = 0
|
||||
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_attn_mask = True
|
||||
stride_zattnm = attn_mask.stride(0)
|
||||
# run kernel
|
||||
M = x.shape[0]
|
||||
grid = [spdims[0] * spdims[1] * block, M]
|
||||
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
|
||||
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
BLOCK=block,
|
||||
APPLY_SCALE=apply_scale,
|
||||
APPLY_RPE=apply_rpe,
|
||||
APPLY_KP_MASK=apply_kp_mask,
|
||||
APPLY_ATTN_MASK=apply_attn_mask,
|
||||
KP_MASK_MUL=(kp_mask_mode == 'mul'),
|
||||
ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
M = a.shape[0]
|
||||
grid = [spdims[0], spdims[1] * block, M]
|
||||
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
|
||||
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=triton.next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
)
|
||||
# save to context
|
||||
ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(x, lut)
|
||||
# ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(out, lut)
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.maxlut = maxlut
|
||||
ctx.scale = scale
|
||||
ctx.apply_scale = apply_scale
|
||||
ctx.apply_rpe = apply_rpe
|
||||
ctx.apply_kp_mask = apply_kp_mask
|
||||
ctx.apply_attn_mask = apply_attn_mask
|
||||
ctx.kp_mask_mode = kp_mask_mode
|
||||
ctx.attn_mask_mode = attn_mask_mode
|
||||
return x
|
||||
ctx.rel_shape = rel_shape
|
||||
ctx.rel_strides = rel_strides
|
||||
ctx.rel_dtype = a.dtype
|
||||
ctx.is_dense = is_dense
|
||||
ctx.is_causal = is_causal
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dx):
|
||||
def backward(ctx, dout):
|
||||
# retrieve from context
|
||||
x, lut = ctx.saved_tensors
|
||||
out, lut = ctx.saved_tensors
|
||||
# relative logits gradients
|
||||
dr = None
|
||||
if ctx.needs_input_grad[3]:
|
||||
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
|
||||
# run kernel
|
||||
M = x.shape[0]
|
||||
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
|
||||
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
|
||||
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
M = out.shape[0]
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def make_lut(self, device):
|
||||
key = (device, )
|
||||
if key not in self.lut_cache:
|
||||
self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
|
||||
return self.lut_cache[key]
|
||||
|
||||
def __init__(self, layout, block):
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
self.block = block
|
||||
self.lut_cache = dict()
|
||||
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
|
||||
self.is_dense = is_dense
|
||||
|
||||
def __call__(
|
||||
self, x, scale=1., rpe=None,
|
||||
key_padding_mask=None, attn_mask=None,
|
||||
key_padding_mask_mode='add', attn_mask_mode='add',
|
||||
is_causal=False
|
||||
):
|
||||
if rpe is not None and rpe.dtype != x.dtype:
|
||||
raise ValueError('relative position embedding must be %s' % x.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != x.dtype:
|
||||
raise ValueError('Attention mask must be %s' % x.dtype)
|
||||
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
|
||||
raise ValueError('Key padding mask must be %s' % x.dtype)
|
||||
lut, maxlut = self.make_lut(x.device)
|
||||
x = _softmax.apply(
|
||||
x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
key_padding_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
self.spdims, self.block,
|
||||
lut, maxlut
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError("relative position embedding must be %s" % a.dtype)
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
return x
|
||||
return a
|
||||
|
@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().float()
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
tri_ret = ref_ret.clone().detach().requires_grad_()
|
||||
return ref_ret, tri_ret
|
||||
|
||||
|
||||
def cutlass_matmul(a, b):
|
||||
if _cutlass is None:
|
||||
raise RuntimeError("Cannot find cutlass library")
|
||||
|
Reference in New Issue
Block a user