[PYTHON] Add Blocksparse Attention Fwd/Bwd Test (#69)

Also includes a small bugfix for block-sparse softmax
This commit is contained in:
Jared Kaplan
2021-02-19 17:46:05 -05:00
committed by Philippe Tillet
parent 7aa4d080b3
commit 045ab5d62a
3 changed files with 149 additions and 42 deletions

View File

@@ -3,44 +3,57 @@ import triton
import pytest


@pytest.mark.parametrize(
    "MODE, TRANS_A, TRANS_B, BLOCK",
    [
        (mode, at, bt, block)
        for mode in ["sdd", "dsd", "dds"]
        for at in [False, True]
        for bt in [False, True]
        for block in [16, 32, 64]
    ],
)
def test_matmul(
    MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn(
        (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
    )
    b = torch.randn(
        (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
    )
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
        "dds": (b.shape[2], b.shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = triton.ops.blocksparse.matmul(
        layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
    )
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = op(ra, rb)
    # torch result
    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    # compare
    assert triton.testing.allclose(rc, tc)


@pytest.mark.parametrize(
    "BLOCK, WIDTH",
    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
)
def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # set seed
    torch.random.manual_seed(0)
@@ -48,27 +61,31 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(
        low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
    )
    kp_mask = torch.randint(
        low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
    )
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
    ty = op(
        tx,
        scale=scale,
        key_padding_mask=kp_mask,
        key_padding_mask_mode="add",
        attn_mask=at_mask.to(DTYPE),
        attn_mask_mode="mul",
    )
    # torch result
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
    if at_mask is not None:
        # broadcast at_mask to the same shape as rx
        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
        rx[M == 0] = float("-inf")
    if kp_mask is not None:
        rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
@@ -76,3 +93,93 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
    # compare
    assert triton.testing.allclose(ry, ty)


def test_attention_fwd_bwd(
    input_scale=1.0,
    tol=2e-2,
    scale=1 / 8.0,
    n_ctx=256,
    dtype=torch.float16,
    batch_size=2,
    n_heads=2,
    block=64,
):
    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
        .to(dtype)
        .cuda()
        for _ in range(3)
    ]
    attn_mask = torch.tril(
        torch.ones(
            [n_ctx, n_ctx],
            device="cuda",
            dtype=dtype,
        ),
        diagonal=0,
    )

    # Triton:
    n_blocks = n_ctx // block
    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
    query, key, value = [x.clone() for x in qkvs]
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(
        layout, block, attn_mask, query=query, key=key, value=value, scale=scale
    )
    # ad hoc loss
    loss = (attn_out ** 2).mean()
    loss.backward()
    grads = [query.grad, key.grad, value.grad]

    # Torch version:
    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
    torch_q.retain_grad()
    torch_k.retain_grad()
    torch_v.retain_grad()
    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
    scores = scores + attn_mask
    probs = torch.softmax(scores, dim=-1)
    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
    # ad hoc loss
    torch_loss = (torch_attn_out ** 2).mean()
    torch_loss.backward()
    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]

    # comparison
    print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
    torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
    for g1, g2 in zip(grads, torch_grads):
        torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)


def triton_attention(
    layout,
    block: int,
    attn_mask: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
        layout, block, "sdd", trans_a=False, trans_b=True
    )
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
        layout, block, "dsd", trans_a=False, trans_b=False
    )
    sparse_softmax = triton.ops.blocksparse.softmax(
        layout,
        block,
    )

    w = sparse_dot_sdd_nt(query, key)
    w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
    a = sparse_dot_dsd_nn(w, value)
    return a
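
For readers who want to exercise the new helper outside of pytest, here is a minimal usage sketch. It assumes `triton_attention` from the test file above is in scope and that a CUDA device is available; the shapes, block size, and causal layout mirror `test_attention_fwd_bwd`, and this driver itself is not part of the commit.

import torch

# hypothetical driver for triton_attention (defined above); mirrors the test's shapes
batch_size, n_heads, n_ctx, head_dim, block = 2, 2, 256, 64, 64
dtype = torch.float16

# causal (lower-triangular) block-sparse layout and matching dense attention mask
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
attn_mask = torch.tril(torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype))

# random queries, keys, and values on the GPU
query, key, value = [
    torch.randn(batch_size, n_heads, n_ctx, head_dim, dtype=dtype, device="cuda")
    for _ in range(3)
]

# the final "dsd" matmul in triton_attention returns a dense tensor,
# so the output has the same shape as `value`
out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=1 / 8.0)
print(out.shape)  # expected: torch.Size([2, 2, 256, 64])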

View File

@@ -126,7 +126,7 @@ class _matmul(torch.autograd.Function):
        num_lock = 1
        key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
        if key not in _matmul.sdd_cache:
            defines = {'TM': block*pack, 'TN': block*pack,
                       'TMN': block*block*pack*pack,
                       'BLOCK': block,
                       'TK': 32,

View File

@@ -110,7 +110,7 @@ class _softmax(torch.autograd.Function):
        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
                                      apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
        M = x.shape[0]
-       grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
+       grid = lambda opt: [spdims[0] * spdims[1] * block, M]
        # run kernel
        kernel(x.data_ptr(),
@@ -151,7 +151,7 @@ class _softmax(torch.autograd.Function):
                                      ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
                                      ctx.kp_mask_mode, ctx.attn_mask_mode)
        M = x.shape[0]
-       grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M]
+       grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
        kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
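
The two grid changes above are the "small bugfix for block-sparse softmax" mentioned in the commit message: the launch grid for both the forward and backward kernels now spans every row of the block-sparse input instead of dividing the row count by the autotuned tile size `opt.TM`. Assuming the kernels process one row per program (which is what the new grid implies), the old formula launched only a fraction of the required programs. A rough sketch with made-up numbers, not taken from the commit:

from math import ceil

# hypothetical values for illustration only
spdims = (2, 4, 4)   # sparsity-layout shape, presumably (heads, block-rows, block-cols)
block = 64           # block size
TM = 32              # stand-in for the autotuned tile size opt.TM

old_grid0 = ceil(spdims[0] * spdims[1] * block / TM)  # 16 programs along axis 0
new_grid0 = spdims[0] * spdims[1] * block             # 512 programs: one per row
print(old_grid0, new_grid0)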