[LANG] Added support for constexpr (#361)

Author: Philippe Tillet
Date: 2021-10-30 00:32:58 -07:00
Committed by: GitHub
Parent: 770ea96cca
Commit: 2acaa4d0dd
16 changed files with 355 additions and 365 deletions
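
In kernel code, this change means that meta-parameters such as block sizes can be annotated as compile-time constants. A minimal sketch of the resulting style (illustrative kernel, not taken from this commit; assumes the tl.constexpr annotation this change introduces):

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # BLOCK is a compile-time constant: each distinct value
    # specializes and recompiles the kernel
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)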


@@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
     # triton result
-    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
     rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
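
As the hunk above shows, blocksparse matmul ops are now constructed with an explicit target device. A minimal usage sketch mirroring the updated test (sizes and the random layout are illustrative):

import torch
import triton

# illustrative: 2 layout heads, 64x64 operands tiled into 16x16 blocks
BLOCK = 16
layout = torch.randint(2, (2, 64 // BLOCK, 64 // BLOCK))
# the device is now passed when the op is constructed, as in the updated test
op = triton.ops.blocksparse.matmul(layout, BLOCK, "sdd", trans_a=False, trans_b=True, device="cuda")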
@@ -151,8 +151,8 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
+    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
+    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
     sparse_softmax = triton.ops.blocksparse.softmax(
         layout,
         block,