From f13cbaab9fc81fa55a799f9cc76b0dab680f80d7 Mon Sep 17 00:00:00 2001 From: TC <93944281+tomconerlyanth@users.noreply.github.com> Date: Mon, 6 Jun 2022 14:37:08 -0400 Subject: [PATCH] [FRONTEND] assert that num_warps is a power of 2 (#539) --- python/test/unit/language/test_core.py | 14 ++++++++++++++ python/triton/code_gen.py | 1 + 2 files changed, 15 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 71df6d73b..b2b2cdeb1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1155,3 +1155,17 @@ def test_if(): x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') ret = torch.empty(1, dtype=torch.float32, device='cuda') kernel[(1,)](cond, x_true, x_false, ret) + + +def test_num_warps_pow2(): + dst = torch.empty(128, device='cuda') + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1,)](dst=dst, num_warps=3) + _kernel[(1,)](dst=dst, num_warps=1) + _kernel[(1,)](dst=dst, num_warps=2) + _kernel[(1,)](dst=dst, num_warps=4) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b64c7eb86..01cd1b5ed 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -954,6 +954,7 @@ class Kernel: return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"{num_warps=} must be a power of 2." # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} wargs = list(wargs)