[FRONTEND] assert that num_warps is a power of 2 (#539)

This commit is contained in:
TC
2022-06-06 14:37:08 -04:00
committed by GitHub
parent 751e325d2e
commit f13cbaab9f
2 changed files with 15 additions and 0 deletions

View File

@@ -1155,3 +1155,17 @@ def test_if():
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret)
def test_num_warps_pow2():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
pass
with pytest.raises(AssertionError, match='must be a power of 2'):
_kernel[(1,)](dst=dst, num_warps=3)
_kernel[(1,)](dst=dst, num_warps=1)
_kernel[(1,)](dst=dst, num_warps=2)
_kernel[(1,)](dst=dst, num_warps=4)

View File

@@ -954,6 +954,7 @@ class Kernel:
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"{num_warps=} must be a power of 2."
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)