[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses the dsd code path for increased code reuse
* at_mask- and kp_mask-related arguments are dropped from the softmax API. I couldn't think of any case where they were needed beyond is_causal, and if there is one, we should probably find a way to implement it statically so that users don't have to materialize masks.
* fixed a bug in blocksparse matmul that caused trouble when the layout had a full row/col of zeros
* blocksparse softmax no longer modifies any data in-place
* blocksparse softmax now takes an is_dense argument that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention.
* unit tests now cover the backward pass
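For orientation before the diff: the reworked softmax API reduces to a constructor that takes the layout, block size, device, and an is_dense flag, and a call that only accepts scale, rel_logits, and is_causal. A minimal usage sketch based on the call sites in the tests below (the shapes and scale value are illustrative, not from the commit):

    import torch
    import triton

    block, n_heads, seq = 32, 2, 512
    # an all-ones layout with is_dense=True and is_causal=True is the
    # recommended way to get triangular (causal) attention
    layout = torch.ones((n_heads, seq // block, seq // block), dtype=torch.int64)
    op = triton.ops.blocksparse.softmax(layout, block, device="cuda", is_dense=True)

    x = torch.randn((1, n_heads, seq, seq), dtype=torch.float16, device="cuda")
    x = triton.testing.sparsify_tensor(x, layout, block)  # (1, nnz, block, block)
    out = op(x, scale=1.0 / seq ** 0.5, is_causal=True)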
@@ -10,77 +10,108 @@ import triton
 @pytest.mark.parametrize("BLOCK", [16, 32, 64])
 @pytest.mark.parametrize("DTYPE", [torch.float16])
 def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
-    # set seed
-    torch.random.manual_seed(0)
+    seed = 0
+    torch.manual_seed(seed)
+    is_sdd = MODE == "sdd"
+    is_dsd = MODE == "dsd"
+    is_dds = MODE == "dds"
+    do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
+    do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
     # create inputs
-    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
-    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
+    # create op
+    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
+    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
+    c_shape = (Z, H, M, N)
     shape = {
         "sdd": (M, N),
-        "dsd": (a.shape[2], a.shape[3]),
-        "dds": (b.shape[2], b.shape[3]),
+        "dsd": (a_shape[2], a_shape[3]),
+        "dds": (b_shape[2], b_shape[3]),
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
+    layout[1, 2, :] = 0
+    layout[1, :, 1] = 0
+    # create data
+    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
+    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
+    dc_ref, dc_tri = triton.testing.make_pair(c_shape)
+    # compute [torch]
+    dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
+    a_ref = do_mask(a_ref) if is_dsd else a_ref
+    b_ref = do_mask(b_ref) if is_dds else b_ref
+    a_ref.retain_grad()
+    b_ref.retain_grad()
+    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
+                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
+    c_ref.backward(dc_ref)
+    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
+    da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
+    db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
     # triton result
+    dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
+    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
+    b_tri = do_sparsify(b_tri) if is_dds else b_tri
+    a_tri.retain_grad()
+    b_tri.retain_grad()
     op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
-    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
-    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
-    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
-    # torch result
-    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
-    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
-    ta = ta.transpose(2, 3) if TRANS_A else ta
-    tb = tb.transpose(2, 3) if TRANS_B else tb
-    tc = torch.matmul(ta, tb)
-    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
-    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
+    c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
+    triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
+    da_tri = a_tri.grad
+    db_tri = b_tri.grad
     # compare
-    triton.testing.assert_almost_equal(rc, tc)
+    triton.testing.assert_almost_equal(c_ref, c_tri)
+    triton.testing.assert_almost_equal(da_ref, da_tri)
+    triton.testing.assert_almost_equal(db_ref, db_tri)
 
 
-@pytest.mark.parametrize("BLOCK", [16, 32, 64])
-@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
-@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
-def test_softmax(BLOCK, WIDTH, DTYPE):
-    is_causal = True
+configs = [
+    (16, 256),
+    (32, 576),
+    (64, 1871),
+    (128, 2511),
+]
 
 
+@pytest.mark.parametrize("is_dense", [False, True])
+@pytest.mark.parametrize("BLOCK, WIDTH", configs)
+def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
     # set seed
     torch.random.manual_seed(0)
-    Z, H, M, N = 1, 1, WIDTH, WIDTH
-    scale = 0.4
-    # create inputs
-    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
-    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
-    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
+    Z, H, M, N = 2, 3, WIDTH, WIDTH
+    # initialize layout
     # make sure each row has at least one non-zero element
-    torch.diagonal(layout)[:] = 1
-    torch.diagonal(at_mask)[:] = 1
-    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
-    kp_mask[:] = 0
-    kp_mask[kp_mask == 1.0] = float("-inf")
-    # triton result
-    op = triton.ops.blocksparse.softmax(layout, BLOCK)
-    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
-    ty = op(
-        tx,
-        scale=scale,
-        key_padding_mask=kp_mask,
-        key_padding_mask_mode="add",
-        attn_mask=at_mask.to(DTYPE),
-        attn_mask_mode="mul",
-        is_causal=is_causal,
-    )
-    # torch result
-    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
-    # broadcast at_mask to the same shape as rx
+    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
+    if is_dense:
+        layout[:] = 1
+    else:
+        layout[1, 2, :] = 0
+        layout[1, :, 1] = 0
+    # initialize data
+    a_shape = (Z, H, M, N)
+    a_ref, a_tri = triton.testing.make_pair(a_shape)
+    dout_ref, dout_tri = triton.testing.make_pair(a_shape)
+    # compute [torch]
+    a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
+    a_ref.retain_grad()
+    at_mask = torch.ones((M, N), device="cuda")
     if is_causal:
         at_mask = torch.tril(at_mask)
-    M = at_mask[None, None, :, :] + torch.zeros_like(rx)
-    rx[M == 0] = float("-inf")
-    # rx += kp_mask[:, None, None, :]
-    ry = torch.softmax(rx * scale, -1)
-    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
+    M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
+    a_ref[M == 0] = float("-inf")
+    out_ref = torch.softmax(a_ref * scale, -1)
+    out_ref.backward(dout_ref)
+    out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
+    da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
+    # compute [triton]
+    a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
+    a_tri.retain_grad()
+    dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
+    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
+    out_tri = op(a_tri, scale=scale, is_causal=is_causal)
+    out_tri.backward(dout_tri)
+    da_tri = a_tri.grad
     # compare
-    triton.testing.assert_almost_equal(ry, ty)
+    triton.testing.assert_almost_equal(out_tri, out_ref)
+    triton.testing.assert_almost_equal(da_tri, da_ref)
 
 
 @pytest.mark.parametrize("block", [16, 32, 64])
@@ -99,14 +130,6 @@ def test_attention_fwd_bwd(
     qkvs = [
         torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
     ]
-    attn_mask = torch.tril(
-        torch.ones(
-            [n_ctx, n_ctx],
-            device="cuda",
-            dtype=dtype,
-        ),
-        diagonal=0,
-    )
 
     # Triton:
     n_blocks = n_ctx // block
@@ -115,7 +138,7 @@ def test_attention_fwd_bwd(
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
+    attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
     # ad hoc loss
     loss = (attn_out ** 2).mean()
     loss.backward()
@@ -123,6 +146,8 @@ def test_attention_fwd_bwd(
 
     # Torch version:
     torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
+    attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
+    attn_mask = torch.tril(attn_mask, diagonal=0)
     attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
     torch_q.retain_grad()
     torch_k.retain_grad()
@@ -147,7 +172,6 @@ def triton_attention(
 def triton_attention(
     layout,
     block: int,
-    attn_mask: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -155,12 +179,9 @@ def triton_attention(
 ):
     sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
     sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
-    sparse_softmax = triton.ops.blocksparse.softmax(
-        layout,
-        block,
-    )
+    sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
 
     w = sparse_dot_sdd_nt(query, key)
-    w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
+    w = sparse_softmax(w, scale=scale, is_causal=True)
     a = sparse_dot_dsd_nn(w, value)
     return a
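Note on the hunk above: dropping the materialized attn_mask works because causal masking is now applied statically inside the kernel when is_causal=True. A dense-torch sketch (illustrative, not part of the commit) of the equivalence the attention test relies on:

    import torch

    n = 8
    w = torch.randn(n, n)
    # old approach: materialize a triangular mask, add -inf above the diagonal
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    ref = torch.softmax(w.masked_fill(mask, float("-inf")), -1)
    # is_causal=True instead restricts row m to columns <= m; same result
    out = torch.stack([
        torch.cat([torch.softmax(w[m, :m + 1], -1), torch.zeros(n - m - 1)])
        for m in range(n)
    ])
    assert torch.allclose(ref, out, atol=1e-6)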
@@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
 
 
 def dsd_lut(layout, block, step, trans, device):
+    """
+    Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
+    Example (BLOCK=32, STEP=16)
+    [[1, 0, 0, 1, 0],
+     [0, 1, 1, 0, 1],
+     [1, 0, 1, 0, 0]]
+
+    Then the offsets for A are
+    [0, 16, 32, 48]  <- row 0
+     \----/  \----/
+     col=0   col=3
+    [64, 80, 96, 112, 128, 144]  <- row 1
+     \----/  \----/   \------/
+     col=1   col=2    col=4
+    [160, 176, 192, 208]  <- row 2
+    which leads to the increments table
+    [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
+
+    Because B is dense, the offsets are
+    [0, 16, 96, 112]  <- row 0
+    [32, 48, 64, 80]  <- row 1
+    [0, 16, 64, 80]   <- row 2
+    """
     sizes = torch.sum(layout, 2 if trans else 1)
-    head_id, col_id = sizes.nonzero(as_tuple=True)
+    head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
     sizes = sizes.flatten()
     segments = sizes * step
     # pointer increments
@@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device):
     # -------------------------------
     # dense input pointer increments
     # -------------------------------
-    # given a list of the indices for the first element of each non-zero block.
-    # For example, for the indices
-    #   [32, 80, 128, 256, 288]
-    # we would generate the increments
-    #   [32, 48, 48, 128, 32]
-    #   ^
-    #   index of first element
     # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
     # that is smaller than the block size, so we need to do a bit of extra work
     # to handle this case
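The comment deleted above described how the dense-input pointer increments are derived from the absolute offsets of each non-zero block's first element: the first increment is the absolute start, every later one is a delta from the previous offset. A short sketch reproducing the numbers from that comment:

    import torch

    # offsets of the first element of each non-zero block (from the old comment)
    offs = torch.tensor([32, 80, 128, 256, 288])
    incs = offs.clone()
    incs[1:] = offs[1:] - offs[:-1]  # absolute offsets -> pointer deltas
    print(incs.tolist())  # [32, 48, 48, 128, 32]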
@@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device):
 # -----------------------------
 # Dense = Dense x Sparse (DDS)
 # -----------------------------
+# AB = (B^T A^T)^T
-@triton.jit
-def _dds_kernel(
-    A, B, C,
-    stride_za, stride_ha, stride_ma, stride_ka,
-    stride_zb, stride_hb, stride_bk, stride_bn,
-    stride_zc, stride_hc, stride_mc, stride_nc,
-    DS0, DS1, lut,
-    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
-):
-    # ------------ #
-    # - Prologue - #
-    # ------------ #
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    num_pid_m = tl.num_programs(0)
-    num_pid_n = tl.num_programs(1)
-    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
-    pid_z = tl.program_id(2)
-    header = lut + pid_n * 4
-    offset = tl.load(header + 0)
-    AS1 = tl.load(header + 1)
-    column = tl.load(header + 2)
-    off_h = tl.load(header + 3)
-    pinc = lut + offset
-    # initialize pointers to A (dense)
-    offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
-    offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
-    start_ak = tl.load(pinc)
-    start_ak = tl.multiple_of(start_ak, 8)
-    offs_ak = start_ak + tl.arange(0, TILE_K)
-    ptrs_a = A + pid_z * stride_za \
-        + off_h * stride_ha \
-        + offs_am[:, None] * stride_ma \
-        + offs_ak[None, :] * stride_ka
-    # initialize pointers to B (sparse)
-    block_id = tl.load(pinc + 1)
-    block_id = tl.multiple_of(block_id, 8)
-    offs_bn = tl.arange(0, TILE_N)
-    offs_bk = tl.arange(0, TILE_K)
-    ptrs_b = B + pid_z * stride_zb \
-        + block_id * stride_hb \
-        + offs_bn[None, :] * stride_bn \
-        + offs_bk[:, None] * stride_bk
-    # ---------------- #
-    #    Inner Loop    #
-    # ---------------- #
-    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
-    for k in range(AS1, 0, -TILE_K):
-        a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
-        b = tl.load(ptrs_b, mask=True)
-        acc += tl.dot(a, b)
-        pinc += 2
-        inc_a = tl.load(pinc)
-        inc_b = tl.load(pinc + 1)
-        inc_a = tl.multiple_of(inc_a, 8)
-        inc_b = tl.multiple_of(inc_b, 8)
-        inc_a = inc_a * stride_ka
-        ptrs_a += inc_a
-        ptrs_b += inc_b
-    # ---------------- #
-    #     Epilogue     #
-    # ---------------- #
-    c = acc.to(C.dtype.element_ty)
-    # initialize pointers to C (dense)
-    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
-    offs_cn = column * TILE_N + tl.arange(0, TILE_N)
-    ptrs_c = C + off_h * stride_hc \
-        + pid_z * stride_zc \
-        + offs_cm[:, None] * stride_mc \
-        + offs_cn[None, :] * stride_nc
-    # write back
-    tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
 
 
 def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
-    if a.stride(2) != 1 and a.stride(3) != 1:
-        a = a.contiguous()
-    if b.stride(2) != 1 and b.stride(3) != 1:
-        b = b.contiguous()
-    # shapes / dtypes
-    AS0 = a.size(0)
-    AS1 = a.size(1)
-    AS2 = a.size(3 if trans_a else 2)
-    BS2 = block * spdims[1 if trans_b else 2]
-    dtype = a.dtype
-    # output
-    CS0 = AS0
-    CS1 = AS1
-    CS2 = BS2 if trans_c else AS2
-    CS3 = AS2 if trans_c else BS2
-    if out is None:
-        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
-    else:
-        assert out.shape == (CS0, CS1, CS2, CS3)
-        c = out
-    TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
-    grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
-    _dds_kernel[grid](
-        a, b, c,
-        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
-        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
-        c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
-        AS2, BS2, lut,
-        TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
-        num_warps=4, GROUP_SIZE_M=4,
-    )
-    return c
+    return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
 
 
 ##############
 #  MAIN API  #
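The whole _dds_kernel (and most of dds_matmul) could be deleted because a dense-times-sparse product is a sparse-times-dense product with the operands swapped and every transposition flag flipped, per the identity AB = (B^T A^T)^T noted in the new comment. A quick dense-tensor sanity check of that identity:

    import torch

    a = torch.randn(8, 16)  # dense operand
    b = torch.randn(16, 8)  # stands in for the sparse operand
    assert torch.allclose(a @ b, (b.t() @ a.t()).t(), atol=1e-5)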
@@ -5,230 +5,235 @@ import triton.language as tl
 
 
 def num_warps(n):
-    if n < 512:
+    if n <= 128:
+        return 1
+    if n <= 256:
+        return 2
+    if n <= 512:
         return 4
-    if n < 2048:
+    if n <= 4096:
         return 8
     return 16
 
 
-@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
-@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
 @triton.jit
-def _forward(
-    X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
-    TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
-    KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
+def _blocksparse_softmax_fwd(
+    Out, A, stride_xz, LUT,
+    R, extent, stride_zr, stride_hr,  # relative attention
+    scale, is_causal,
+    ROW_SIZE: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    IS_DENSE: tl.constexpr,
 ):
-    pidhm = tl.program_id(0)
-    pidz = tl.program_id(1)
+    h = tl.program_id(0)
+    m = tl.program_id(1)
+    z = tl.program_id(2)
     # create index ranges
-    rxm = pidhm % BLOCK
-    rbm = pidhm // BLOCK
-    rxn = tl.arange(0, TN) % BLOCK
-    rbn = tl.arange(0, TN) // BLOCK
+    hm = h * tl.num_programs(1) + m
+    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
+    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
     # extract information from LUT
-    header = LUT + rbm * 2
+    header = LUT + (hm // BLOCK_SIZE) * 2
     size = tl.load(header + 0)
     offset = tl.load(header + 1)
-    check = rbn < size
-    rbmn = tl.where(check, rbn, size - 1)
-    # block id and column id
-    blockid = tl.load(LUT + offset + rbmn * 4 + 0)
-    columnid = tl.load(LUT + offset + rbmn * 4 + 1)
-    rowid = tl.load(LUT + offset + rbmn * 4 + 2)
-    headid = tl.load(LUT + offset + rbmn * 4 + 3)
-    # pointers to X
-    px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
-    x = tl.load(px, mask=check, other=-float('inf'))
-    x = x.to(tl.float32)
-    # apply scale
-    if APPLY_SCALE:
-        x = x * scale
-    # apply RPE
-    if APPLY_RPE:
-        prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
-        rpe = tl.load(prpe, mask=check, other=0)
-        x = x + rpe
-    # apply key-padding mask
-    if APPLY_KP_MASK:
-        pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
-        kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
-        if KP_MASK_MUL:
-            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
-        x = x + kp_m
-    # apply attention mask
-    if APPLY_ATTN_MASK:
-        pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
-        attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
-        if ATTN_MASK_MUL:
-            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
-        x = x + attn_m
+    # pointer offset
+    off_a = z * stride_xz
+    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block indx
+    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row indx
+    # do not need to read column indices in the dense case
+    if IS_DENSE:
+        ns = tl.arange(0, ROW_SIZE)
+    else:
+        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
+        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
+        ns = start_n * BLOCK_SIZE + lane_n
+    # load X
+    mask = block_n < size
+    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
+    a = a.to(tl.float32)
+    # compute
+    out = a
+    out *= scale
+    # apply relative attention
+    if R is not None:
+        R += z * stride_zr
+        R += h * stride_hr
+        off_lo = (extent - m - 1) + ns
+        mask_lo = (off_lo >= 0) & (off_lo < extent)
+        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
+        out += rel_logits
+    out = out.to(tl.float32)
     # apply causal mask
-    is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
-    x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
+    out = tl.where((ns > m) & is_causal, -float("inf"), out)
     # computation
-    x = tl.softmax(x)
-    tl.store(px, x, mask=check)
+    out = tl.softmax(out)
+    # write-back
+    tl.store(Out + off_a + lane_n, out, mask=mask)
 
 
-@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
-@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
 @triton.jit
-def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
-    pidhm = tl.program_id(0)
-    pidz = tl.program_id(1)
+def _blocksparse_softmax_bwd(
+    DA, stride_zdx,
+    DOut, stride_zdout,
+    Out, stride_zout,
+    scale,
+    LUT,
+    DR, extent, stride_zr, stride_hr, stride_er,
+    is_causal,
+    ROW_SIZE: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    IS_DENSE: tl.constexpr,
+):
+    h = tl.program_id(0)
+    m = tl.program_id(1)
+    z = tl.program_id(2)
     # create index ranges
-    rxm = pidhm % BLOCK
-    rbm = pidhm // BLOCK
-    rxn = tl.arange(0, TN) % BLOCK
-    rbn = tl.arange(0, TN) // BLOCK
-    # extract information from look-up table
-    header = LUT + rbm * 2
+    hm = h * tl.num_programs(1) + m
+    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
+    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
+    # extract information from LUT
+    header = LUT + (hm // BLOCK_SIZE) * 2
     size = tl.load(header + 0)
     offset = tl.load(header + 1)
-    # bounds checking on lut
-    check = rbn < size
-    rbmn = tl.where(check, rbn, size - 1)
-    # initialize pointers to block-sparse input
-    blockid = tl.load(LUT + offset + rbmn * 4)
-    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
-    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
-    # compute fused softmax backward
-    x = tl.load(X, mask=check, other=0)
-    dx = tl.load(DX, mask=check, other=0)
-    x = x.to(tl.float32)
-    dx = dx.to(tl.float32)
-    y = x * (dx - tl.sum(x * dx, 0)) * scale
-    tl.store(DX, y, mask=check)
+    # row-col offset
+    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
+    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
+    mask = block_n < size
+    # pointers
+    As = Out + z * stride_zout + off_mn
+    DOuts = DOut + z * stride_zdout + off_mn
+    # do not need to read column indices in the dense case
+    if IS_DENSE:
+        ns = tl.arange(0, ROW_SIZE)
+    else:
+        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
+        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
+        ns = start_n * BLOCK_SIZE + lane_n
+    # load data
+    a = tl.load(As + lane_n, mask=mask, other=0.0)
+    a = a.to(tl.float32)
+    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
+    dout = dout.to(tl.float32)
+    # compute
+    da = a * (dout - tl.sum(a * dout, 0))
+    da = tl.where((ns > m) & is_causal, 0., da)
+    # apply relative attention
+    if DR is not None:
+        DR += z * stride_zr
+        DR += h * stride_hr
+        off_lo = (extent - m - 1) + ns
+        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
+        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
+    da = da * scale
+    # write-back
+    DAs = DA + z * stride_zdx + off_mn
+    tl.store(DAs + lane_n, da, mask=mask)
 
 
 class _softmax(torch.autograd.Function):
     @staticmethod
     def make_lut(layout, block, device):
+        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
+        sizes = _empty.clone()
         # sizes along rows
-        sizes = layout.sum(-1).view(-1)
+        for h in range(layout.shape[0]):
+            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+        total_sizes = sizes * block
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
         # block indices
-        layout_sum = sizes.sum()
-        idx = torch.arange(layout_sum, device=layout.device)
-        layout_nonzero = layout.nonzero(as_tuple=False)
-        head = layout_nonzero[:, 0]
-        rows = layout_nonzero[:, 1]
-        columns = layout_nonzero[:, 2]
-        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
-        # construct look-up table
-        offsets = offsets * 4 + 2 * sizes.numel()
+        columns = layout.nonzero(as_tuple=False)[:, 2]
         header = torch.stack((sizes, offsets), dim=1).view(-1)
-        lut = torch.cat((header, core)).type(torch.int32).to(device)
-        return lut, int(sizes.max())
+        lut = torch.cat((header, columns)).type(torch.int32).to(device)
+        return lut, int(total_sizes.max())
 
     @staticmethod
     def forward(
-        ctx, x, scale, rpe,
-        key_padding_mask, attn_mask,
-        kp_mask_mode, attn_mask_mode,
-        is_causal,
-        spdims, block, lut, maxlut
+        ctx, a, scale, rel_logits, is_causal,
+        spdims, block, lut, maxlut, is_dense
     ):
-        apply_scale = False if scale == 1.0 else True
-        # handle None rpe
-        if rpe is None:
-            apply_rpe = False
-            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
-            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
-        else:
-            apply_rpe = True
-            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
-        # handle None key_padding_mask
-        if key_padding_mask is None:
-            apply_kp_mask = False
-            stride_zkpm = 0
-            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
-        else:
-            apply_kp_mask = True
-            stride_zkpm = key_padding_mask.stride(0)
-        # handle None attention_mask
-        if attn_mask is None:
-            apply_attn_mask = False
-            stride_zattnm = 0
-            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
-        else:
-            apply_attn_mask = True
-            stride_zattnm = attn_mask.stride(0)
-        # run kernel
-        M = x.shape[0]
-        grid = [spdims[0] * spdims[1] * block, M]
-        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
-                       BLOCK=block,
-                       APPLY_SCALE=apply_scale,
-                       APPLY_RPE=apply_rpe,
-                       APPLY_KP_MASK=apply_kp_mask,
-                       APPLY_ATTN_MASK=apply_attn_mask,
-                       KP_MASK_MUL=(kp_mask_mode == 'mul'),
-                       ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
+        if scale is not None and isinstance(scale, torch.Tensor):
+            assert scale.device.type == "cpu"
+            scale = scale.item()
+        M = a.shape[0]
+        grid = [spdims[0], spdims[1] * block, M]
+        rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
+        rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
+        # enqueue kernel
+        out = torch.empty_like(a)
+        _blocksparse_softmax_fwd[grid](
+            out, a, a.stride(0), lut,
+            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
+            scale,
+            is_causal,
+            BLOCK_SIZE=block,
+            ROW_SIZE=triton.next_power_of_2(maxlut),
+            IS_DENSE=is_dense,
+            num_warps=num_warps(maxlut)
+        )
         # save to context
-        ctx.mark_dirty(x)
-        ctx.save_for_backward(x, lut)
+        # ctx.mark_dirty(x)
+        ctx.save_for_backward(out, lut)
         ctx.spdims = spdims
         ctx.block = block
         ctx.maxlut = maxlut
         ctx.scale = scale
-        ctx.apply_scale = apply_scale
-        ctx.apply_rpe = apply_rpe
-        ctx.apply_kp_mask = apply_kp_mask
-        ctx.apply_attn_mask = apply_attn_mask
-        ctx.kp_mask_mode = kp_mask_mode
-        ctx.attn_mask_mode = attn_mask_mode
-        return x
+        ctx.rel_shape = rel_shape
+        ctx.rel_strides = rel_strides
+        ctx.rel_dtype = a.dtype
+        ctx.is_dense = is_dense
+        ctx.is_causal = is_causal
+        return out
 
     @staticmethod
-    def backward(ctx, dx):
+    def backward(ctx, dout):
         # retrieve from context
-        x, lut = ctx.saved_tensors
+        out, lut = ctx.saved_tensors
+        # relative logits gradients
+        dr = None
+        if ctx.needs_input_grad[3]:
+            dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
         # run kernel
-        M = x.shape[0]
-        grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
-        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        M = out.shape[0]
+        grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
+        da = torch.empty_like(dout)
+        _blocksparse_softmax_bwd[grid](
+            da, da.stride(0),
+            dout, dout.stride(0),
+            out, out.stride(0),
+            ctx.scale,
+            lut,
+            dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
+            ctx.is_causal,
+            BLOCK_SIZE=ctx.block,
+            ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
+            IS_DENSE=ctx.is_dense,
+            num_warps=num_warps(ctx.maxlut)
+        )
+        return (da, None, None, dr, None,
+                None, None, None, None, None,
+                None, None, None, None, None,
+                None, None, None)
 
 
 class softmax:
-    def make_lut(self, device):
-        key = (device, )
-        if key not in self.lut_cache:
-            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
-        return self.lut_cache[key]
-
-    def __init__(self, layout, block):
+    def __init__(self, layout, block, device, is_dense=False):
         self.spdims = layout.shape
         self.layout = layout
         self.block = block
-        self.lut_cache = dict()
+        self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
+        self.is_dense = is_dense
 
-    def __call__(
-        self, x, scale=1., rpe=None,
-        key_padding_mask=None, attn_mask=None,
-        key_padding_mask_mode='add', attn_mask_mode='add',
-        is_causal=False
-    ):
-        if rpe is not None and rpe.dtype != x.dtype:
-            raise ValueError('relative position embedding must be %s' % x.dtype)
-        if attn_mask is not None and attn_mask.dtype != x.dtype:
-            raise ValueError('Attention mask must be %s' % x.dtype)
-        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
-            raise ValueError('Key padding mask must be %s' % x.dtype)
-        lut, maxlut = self.make_lut(x.device)
-        x = _softmax.apply(
-            x, scale, rpe,
-            key_padding_mask, attn_mask,
-            key_padding_mask_mode, attn_mask_mode,
-            is_causal,
-            self.spdims, self.block,
-            lut, maxlut
+    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
+        if rel_logits is not None and rel_logits.dtype != a.dtype:
+            raise ValueError("relative position embedding must be %s" % a.dtype)
+        a = _softmax.apply(
+            a, scale, rel_logits, is_causal,
+            self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
         )
-        return x
+        return a
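One subtlety in the rewritten _softmax.backward above: torch.autograd.Function requires backward to return one entry per argument of forward, with None for non-differentiable inputs (scale, is_causal, the LUT, and so on), which is why the return tuple is padded with Nones. A minimal, self-contained illustration of the pattern (ScaledExp is a hypothetical example, not part of the commit):

    import torch

    class ScaledExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            out = torch.exp(x * scale)
            ctx.save_for_backward(out)
            ctx.scale = scale
            return out

        @staticmethod
        def backward(ctx, dout):
            out, = ctx.saved_tensors
            # one entry per forward input: a gradient for x, None for scale
            return dout * out * ctx.scale, None

    x = torch.randn(4, requires_grad=True)
    ScaledExp.apply(x, 0.5).sum().backward()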
@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
     return ret
 
 
+def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
+    if data is None:
+        data = torch.randn(shape, dtype=torch.float32, device=device)
+    ref_ret = data
+    ref_ret = ref_ret * alpha + beta
+    ref_ret = ref_ret.half().float()
+    if trans:
+        ref_ret = ref_ret.t().requires_grad_()
+    ref_ret = ref_ret.detach().requires_grad_()
+    tri_ret = ref_ret.clone().detach().requires_grad_()
+    return ref_ret, tri_ret
+
+
 def cutlass_matmul(a, b):
     if _cutlass is None:
         raise RuntimeError("Cannot find cutlass library")
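make_pair is the helper the reworked tests lean on: it returns two independent leaf tensors holding identical values, rounded through fp16 so the reference path and the Triton path start from the same representable data. Typical use, as in the tests above:

    a_ref, a_tri = triton.testing.make_pair((2, 3, 128, 128), alpha=.1)
    # a_ref feeds the torch reference, a_tri feeds the Triton op;
    # gradients then accumulate independently on each copy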