diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 75d75e8ea..046e60e0e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -380,6 +380,9 @@ class tensor: self.numel = 1 for s in self.shape: self.numel *= s + is_pow2 = (self.numel and (not(self.numel & (self.numel - 1)))) + if not is_pow2: + raise ValueError("Triton tensors must have a power-of-two number of elements") self.numel = constexpr(self.numel) self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type