[FRONTEND] assert that num_warps is a power of 2 (#539)
This commit is contained in:
@@ -1155,3 +1155,17 @@ def test_if():
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret)
|
||||
|
||||
|
||||
def test_num_warps_pow2():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
pass
|
||||
|
||||
with pytest.raises(AssertionError, match='must be a power of 2'):
|
||||
_kernel[(1,)](dst=dst, num_warps=3)
|
||||
_kernel[(1,)](dst=dst, num_warps=1)
|
||||
_kernel[(1,)](dst=dst, num_warps=2)
|
||||
_kernel[(1,)](dst=dst, num_warps=4)
|
||||
|
@@ -954,6 +954,7 @@ class Kernel:
|
||||
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
|
||||
assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"{num_warps=} must be a power of 2."
|
||||
# handle arguments passed by name
|
||||
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
|
||||
wargs = list(wargs)
|
||||
|
Reference in New Issue
Block a user