[FRONTEND] assert that num_warps is a power of 2 (#539)
This commit is contained in:
@@ -1155,3 +1155,17 @@ def test_if():
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret)
|
||||
|
||||
|
||||
def test_num_warps_pow2():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
pass
|
||||
|
||||
with pytest.raises(AssertionError, match='must be a power of 2'):
|
||||
_kernel[(1,)](dst=dst, num_warps=3)
|
||||
_kernel[(1,)](dst=dst, num_warps=1)
|
||||
_kernel[(1,)](dst=dst, num_warps=2)
|
||||
_kernel[(1,)](dst=dst, num_warps=4)
|
||||
|
Reference in New Issue
Block a user