[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)

* The dds layout now internally re-uses the dsd code path for increased code reuse.
* at_mask and kp_mask arguments are dropped from the softmax API. I couldn't think of any case where they were needed beyond is_causal; if one comes up, we should find a way to implement it statically so that users don't have to materialize masks.
* Fixed a bug in blocksparse matmul that caused trouble when the layout had a full row/column of zeros.
* Blocksparse softmax no longer modifies any data in-place.
* Blocksparse softmax now takes an is_dense argument that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention (see the sketch below).
* Unit tests now exercise the backward pass.
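For context, here is a minimal sketch of the updated softmax API (the shapes, sizes and variable names are illustrative, not taken from this diff): the op is built once per layout/device, inputs are packed into block-sparse form with triton.testing.sparsify_tensor, and triangular attention is obtained by passing is_dense=True together with is_causal=True.

import torch
import triton

Z, H, M, N, BLOCK = 2, 2, 256, 256, 32
# fully dense layout: every (M//BLOCK) x (N//BLOCK) block is present
layout = torch.ones((H, M // BLOCK, N // BLOCK), dtype=torch.int64)
# pack the non-zero blocks of a dense tensor into block-sparse format
x = torch.randn((Z, H, M, N), dtype=torch.float16, device="cuda")
x = triton.testing.sparsify_tensor(x, layout, BLOCK)
# construct the op once, then reuse it
softmax = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=True)
y = softmax(x, scale=1.0, is_causal=True)  # causal (triangular) softmax over each row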
Philippe Tillet
2022-02-06 18:00:45 -08:00
committed by GitHub
parent 69ff52ea1f
commit 5a8a544d10
4 changed files with 311 additions and 361 deletions


@@ -10,77 +10,108 @@ import triton
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
# set seed
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
c_shape = (Z, H, M, N)
shape = {
"sdd": (M, N),
"dsd": (a_shape[2], a_shape[3]),
"dds": (b_shape[2], b_shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.retain_grad()
b_ref.retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
# triton result
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.retain_grad()
b_tri.retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE):
is_causal = True
configs = [
(16, 256),
(32, 576),
(64, 1871),
(128, 2511),
]
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 1, 1, WIDTH, WIDTH
scale = 0.4
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
Z, H, M, N = 2, 3, WIDTH, WIDTH
# initialize layout
# make sure each row has at least one non-zero element
torch.diagonal(layout)[:] = 1
torch.diagonal(at_mask)[:] = 1
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[:] = 0
kp_mask[kp_mask == 1.0] = float("-inf")
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
ty = op(
tx,
scale=scale,
key_padding_mask=kp_mask,
key_padding_mask_mode="add",
attn_mask=at_mask.to(DTYPE),
attn_mask_mode="mul",
is_causal=is_causal,
)
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
if is_dense:
layout[:] = 1
else:
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :]
ry = torch.softmax(rx * scale, -1)
ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(ry, ty)
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@@ -99,14 +130,6 @@ def test_attention_fwd_bwd(
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
# Triton:
n_blocks = n_ctx // block
@@ -115,7 +138,7 @@ def test_attention_fwd_bwd(
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss.backward()
@@ -123,6 +146,8 @@ def test_attention_fwd_bwd(
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
attn_mask = torch.tril(attn_mask, diagonal=0)
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
@@ -147,7 +172,6 @@ def test_attention_fwd_bwd(
def triton_attention(
layout,
block: int,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -155,12 +179,9 @@ def triton_attention(
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
a = sparse_dot_dsd_nn(w, value)
return a
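
As a usage sketch of the helper above (the sizes and the all-ones layout are made up for illustration, and torch/triton are assumed to be imported as in the test module), causality is now requested through the sparse softmax instead of a materialized attention mask:

Z, H, n_ctx, d_head, block = 2, 4, 256, 64, 16
layout = torch.ones((H, n_ctx // block, n_ctx // block), dtype=torch.int64)
q, k, v = [torch.randn((Z, H, n_ctx, d_head), dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3)]
out = triton_attention(layout, block, query=q, key=k, value=v, scale=1.0 / d_head ** 0.5)
out.sum().backward()  # gradients flow through the block-sparse matmul and softmax kernels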


@@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
def dsd_lut(layout, block, step, trans, device):
"""
Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
Example (BLOCK=32, STEP=16)
[[1, 0, 0, 1, 0],
[0, 1, 1, 0, 1],
[1, 0, 1, 0, 0]]
Then the offsets for A are
[0 , 16, 32, 48] <- row 0
\\----/ \\----/
col=0 col=3
[64, 80, 96, 112, 128, 144] <- row 1
\\----/ \\----/ \\------/
col=1 col=2 col=3
[160, 176, 192, 208] <- row 2
which leads to increments table
[0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
Because B is dense, the offsets are
[0, 16, 96, 112] <- row 0
[32, 48, 64, 80] <- row 1
[0, 16, 64, 80] <- row 2
"""
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes * step
# pointer increments
@@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device):
# -------------------------------
# dense input pointer increments
# -------------------------------
# given a list of the indices for the first element of each non-zero block.
# For example, for the indices
# [32, 80, 128, 256, 288]
# we would generate the increments
# [32, 48, 48, 128, 32]
# ^
# index of first element
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
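
To make the pointer arithmetic described above concrete, this small stand-alone snippet (illustrative only, not part of the file) recomputes the per-row A offsets for the docstring's example layout with BLOCK=32 and STEP=16:

import torch

BLOCK, STEP = 32, 16
layout = torch.tensor([[1, 0, 0, 1, 0],
                       [0, 1, 1, 0, 1],
                       [1, 0, 1, 0, 0]])
sizes = layout.sum(1)  # non-zero blocks per row: [2, 3, 2]
starts = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)[:-1]])
for row in range(layout.shape[0]):
    start, size = int(starts[row]), int(sizes[row])
    # packed block b covers columns [b * BLOCK, (b + 1) * BLOCK) of A, visited STEP at a time
    offs = [b * BLOCK + s for b in range(start, start + size) for s in range(0, BLOCK, STEP)]
    print(row, offs)
# prints: 0 [0, 16, 32, 48]
#         1 [64, 80, 96, 112, 128, 144]
#         2 [160, 176, 192, 208]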
@@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device):
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@triton.jit
def _dds_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pid_z = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (dense)
offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K)
ptrs_a = A + pid_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \
+ block_id * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(AS1, 0, -TILE_K):
a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
b = tl.load(ptrs_b, mask=True)
acc += tl.dot(a, b)
pinc += 2
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
ptrs_a += inc_a
ptrs_b += inc_b
# ---------------- #
# Epilogue #
# ---------------- #
c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense)
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
# write back
tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
# AB = (B^T A^T)^T
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
_dds_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut,
TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
return c
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
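
The one-line delegation above works because of the transpose identity A @ B = (B^T @ A^T)^T: swapping the operands and flipping every transpose flag turns the dense-times-sparse product into the already-implemented sparse-times-dense one. A quick dense sanity check of the identity in plain PyTorch (illustrative only):

import torch

a = torch.randn(4, 8, 64, 32)
b = torch.randn(4, 8, 32, 48)
left = torch.matmul(a, b)
right = torch.matmul(b.transpose(2, 3), a.transpose(2, 3)).transpose(2, 3)
assert torch.allclose(left, right, atol=1e-4, rtol=1e-4)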
##############
# MAIN API #


@@ -5,230 +5,235 @@ import triton.language as tl
def num_warps(n):
if n <= 128:
return 1
if n <= 256:
return 2
if n <= 512:
return 4
if n <= 4096:
return 8
return 16
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# pointer offset
off_a = z * stride_xz
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load X
mask = block_n < size
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
a = a.to(tl.float32)
# compute
out = a
out *= scale
# apply relative attention
if R is not None:
R += z * stride_zr
R += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent)
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
out += rel_logits
out = out.to(tl.float32)
# apply causal mask
out = tl.where((ns > m) & is_causal, -float("inf"), out)
# computation
out = tl.softmax(out)
# write-back
tl.store(Out + off_a + lane_n, out, mask=mask)
@triton.jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# row-col offset
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
mask = block_n < size
# pointers
As = Out + z * stride_zout + off_mn
DOuts = DOut + z * stride_zdout + off_mn
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load data
a = tl.load(As + lane_n, mask=mask, other=0.0)
a = a.to(tl.float32)
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
dout = dout.to(tl.float32)
# compute
da = a * (dout - tl.sum(a * dout, 0))
da = tl.where((ns > m) & is_causal, 0., da)
# apply relative attention
if DR is not None:
DR += z * stride_zr
DR += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
da = da * scale
# write-back
DAs = DA + z * stride_zdx + off_mn
tl.store(DAs + lane_n, da, mask=mask)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
total_sizes = sizes * block
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# construct look-up table
columns = layout.nonzero(as_tuple=False)[:, 2]
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, columns)).type(torch.int32).to(device)
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
M = a.shape[0]
grid = [spdims[0], spdims[1] * block, M]
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=triton.next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
)
# save to context
ctx.save_for_backward(out, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.rel_shape = rel_shape
ctx.rel_strides = rel_strides
ctx.rel_dtype = a.dtype
ctx.is_dense = is_dense
ctx.is_causal = is_causal
return out
@staticmethod
def backward(ctx, dout):
# retrieve from context
out, lut = ctx.saved_tensors
# relative logits gradients
dr = None
if ctx.needs_input_grad[3]:
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
# run kernel
M = out.shape[0]
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
self.block = block
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
self.is_dense = is_dense
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError("relative position embedding must be %s" % a.dtype)
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
return a


@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
if data is None:
data = torch.randn(shape, dtype=torch.float32, device=device)
ref_ret = data
ref_ret = ref_ret * alpha + beta
ref_ret = ref_ret.half().float()
if trans:
ref_ret = ref_ret.t().requires_grad_()
ref_ret = ref_ret.detach().requires_grad_()
tri_ret = ref_ret.clone().detach().requires_grad_()
return ref_ret, tri_ret
def cutlass_matmul(a, b):
if _cutlass is None:
raise RuntimeError("Cannot find cutlass library")