From c82a2066841208a048335e019517c2530944e34b Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 26 May 2022 17:41:09 -0700
Subject: [PATCH] [FRONTEND] Better dot error message (#531)

---
 python/test/unit/language/test_core.py |  2 +-
 python/triton/language/core.py         | 20 ++++++++++++++++++++
 python/triton/language/semantic.py     |  4 ++++
 3 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 952922f6b..50bfb9d1c 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -910,7 +910,7 @@ def test_arange(start, device='cuda'):
 def test_masked_load_shared_memory(dtype, device='cuda'):
     M = 32
     N = 32
-    K = 8
+    K = 16
 
     in1 = torch.rand((M, K), dtype=dtype, device=device)
     in2 = torch.rand((K, N), dtype=dtype, device=device)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 7ef63abba..f81645a36 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -342,6 +342,26 @@ class constexpr:
     def __bool__(self):
         return bool(self.value)
 
+    def __ge__(self, other):
+        other = other.value if isinstance(other, constexpr) else other
+        return self.value >= other
+
+    def __gt__(self, other):
+        other = other.value if isinstance(other, constexpr) else other
+        return self.value > other
+
+    def __le__(self, other):
+        other = other.value if isinstance(other, constexpr) else other
+        return self.value <= other
+
+    def __lt__(self, other):
+        other = other.value if isinstance(other, constexpr) else other
+        return self.value < other
+
+    def __eq__(self, other):
+        other = other.value if isinstance(other, constexpr) else other
+        return self.value == other
+
     def __call__(self, *args, **kwds):
         return self.value(*args, **kwds)
 
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 10d20fbb3..2af25cbb2 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -905,6 +905,10 @@ def dot(lhs: tl.tensor,
         allow_tf32: bool,
         builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
+    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
+    assert lhs.shape[-1] == rhs.shape[0]
+    assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
+        "small blocks not supported!"
     if lhs.type.scalar.is_int():
         _0 = builder.get_int32(0)
         ret_scalar_ty = tl.int32