[FRONTEND] assert that num_warps is a power of 2 (#539)

This commit is contained in:
TC
2022-06-06 14:37:08 -04:00
committed by GitHub
parent 751e325d2e
commit f13cbaab9f
2 changed files with 15 additions and 0 deletions

View File

@@ -1155,3 +1155,17 @@ def test_if():
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret)
def test_num_warps_pow2():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
pass
with pytest.raises(AssertionError, match='must be a power of 2'):
_kernel[(1,)](dst=dst, num_warps=3)
_kernel[(1,)](dst=dst, num_warps=1)
_kernel[(1,)](dst=dst, num_warps=2)
_kernel[(1,)](dst=dst, num_warps=4)