[PYTHON][OPS] Added block-sparse softmax

This commit is contained in:
Philippe Tillet
2021-01-30 19:58:42 -05:00
parent f81da73b6a
commit 7ba242fcce
4 changed files with 373 additions and 2 deletions

View File

@@ -23,7 +23,7 @@ def mask_tensor(x, mask, block, value = 0):
for block in [16, 32, 64]
]
)
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
# set seed
torch.random.manual_seed(0)
# create inputs
@@ -48,3 +48,42 @@ def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2,
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
@pytest.mark.parametrize(
    "BLOCK, WIDTH",
    [(block, width) for block in [16] for width in [256, 576]],
)
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
    """Compare ``tt.ops.blocksparse.softmax`` with a dense ``torch.softmax``.

    A random 0/1 block layout, a random input, a boolean attention mask and
    an additive key-padding mask (entries are 0 or -inf) are generated on a
    fixed seed.  The triton op is run on the sparsified input with the
    key-padding mask in ``'add'`` mode and the attention mask in ``'mul'``
    mode.  The reference applies the same masking to the dense tensor by
    writing -inf, runs ``torch.softmax`` over the last dimension, and is then
    sparsified with the same layout before comparison.

    Args:
        BLOCK: block size passed to the op and to the sparsify/mask helpers.
        WIDTH: used for both of the last two dimensions of the input.
        DTYPE: dtype of the input and key-padding mask; must be
            ``torch.float32`` or ``torch.float16`` (the tolerance table has
            no other entries).
    """
    neg_inf = float('-inf')

    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    scale = 0.4

    # create inputs -- the order of the random draws is kept as-is so the
    # seeded values do not change
    layout = torch.randint(2, (H, M//BLOCK, N//BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
    at_mask = torch.randint(
        low=0, high=2, size=(N, N),
        dtype=torch.bool, requires_grad=False, device='cuda',
    )
    kp_mask = torch.randint(
        low=0, high=2, size=(Z, N),
        dtype=DTYPE, requires_grad=False, device='cuda',
    )
    # turn the 0/1 draw into an additive mask: 0 keeps a key, -inf drops it
    kp_mask[kp_mask == 1.] = neg_inf

    # triton result
    softmax_op = tt.ops.blocksparse.softmax(layout, BLOCK)
    sparse_x = sparsify_tensor(x, layout, BLOCK)
    ty = softmax_op(
        sparse_x,
        scale=scale,
        key_padding_mask=kp_mask,
        key_padding_mask_mode='add',
        attn_mask=at_mask.to(DTYPE),
        attn_mask_mode='mul',
    )

    # torch result
    rx = mask_tensor(x, layout, BLOCK, value=neg_inf)
    # broadcast at_mask to the same shape as rx, then drop the positions
    # where the attention mask is zero
    dense_at_mask = at_mask[None, None, :, :] + torch.zeros_like(rx)
    rx[dense_at_mask == 0] = neg_inf
    # additive key-padding mask, broadcast over the H and M dimensions
    rx += kp_mask[:, None, None, :]
    ry = sparsify_tensor(torch.softmax(rx*scale, -1), layout, BLOCK)

    # compare
    tolerances = {
        torch.float32: (1e-4, 1e-5),
        torch.float16: (1e-2, 1e-3),
    }
    rtol, atol = tolerances[DTYPE]
    assert torch.allclose(ry, ty, rtol=rtol, atol=atol)