[FRONTEND] Better dot error message (#531)

Philippe Tillet
2022-05-26 17:41:09 -07:00
committed by GitHub
parent 0e2883020a
commit c82a206684
3 changed files with 25 additions and 1 deletion


@@ -910,7 +910,7 @@ def test_arange(start, device='cuda'):
 def test_masked_load_shared_memory(dtype, device='cuda'):
     M = 32
     N = 32
-    K = 8
+    K = 16
     in1 = torch.rand((M, K), dtype=dtype, device=device)
     in2 = torch.rand((K, N), dtype=dtype, device=device)
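(With the new minimum block-size check on tl.dot added in this commit, every dot dimension must be at least 16, so the shared dimension K in this test is bumped from 8 to 16.)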


@@ -342,6 +342,26 @@ class constexpr:
    def __bool__(self):
        return bool(self.value)
    def __ge__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return self.value >= other
    def __gt__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return self.value > other
    def __le__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return self.value <= other
    def __lt__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return self.value < other
    def __eq__(self, other):
        other = other.value if isinstance(other, constexpr) else other
        return self.value == other
    def __call__(self, *args, **kwds):
        return self.value(*args, **kwds)
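For context, a minimal standalone sketch (not the actual triton.language.constexpr class, just the pattern shown above) of why the comparison methods unwrap a constexpr operand before delegating to the wrapped value: it lets constexpr block shapes be compared against both plain Python ints and other constexpr objects, which the new dot() assertions below rely on.

# Minimal sketch of the unwrap-then-compare pattern; this `constexpr` is a
# stripped-down stand-in for triton.language.constexpr.
class constexpr:
    def __init__(self, value):
        self.value = value

    def __ge__(self, other):
        # Unwrap a constexpr operand so the comparison happens on plain values.
        other = other.value if isinstance(other, constexpr) else other
        return self.value >= other

BLOCK_K = constexpr(8)
print(BLOCK_K >= 16)                    # False -> would trip the new dot() check
print(constexpr(32) >= constexpr(16))   # True: constexpr vs. constexpr also works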


@@ -905,6 +905,10 @@ def dot(lhs: tl.tensor,
         allow_tf32: bool,
         builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
+    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
+    assert lhs.shape[-1] == rhs.shape[0]
+    assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
+        "small blocks not supported!"
     if lhs.type.scalar.is_int():
         _0 = builder.get_int32(0)
         ret_scalar_ty = tl.int32
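To illustrate the effect (a hedged sketch, not part of the commit; the kernel name, strides, and launch below are illustrative only): a kernel along these lines, where the shared dot dimension is only 8, now fails at compile time with the explicit AssertionError "small blocks not supported!" instead of a less descriptive backend error.

import torch
import triton
import triton.language as tl

@triton.jit
def tiny_dot_kernel(a_ptr, b_ptr, c_ptr):
    # 16x8 @ 8x16: the shared dimension (8) is below the new minimum of 16.
    offs_m = tl.arange(0, 16)
    offs_k = tl.arange(0, 8)
    offs_n = tl.arange(0, 16)
    a = tl.load(a_ptr + offs_m[:, None] * 8 + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * 16 + offs_n[None, :])
    c = tl.dot(a, b)  # lhs.shape[1] == 8 -> "small blocks not supported!"
    tl.store(c_ptr + offs_m[:, None] * 16 + offs_n[None, :], c)

a = torch.randn((16, 8), device='cuda', dtype=torch.float16)
b = torch.randn((8, 16), device='cuda', dtype=torch.float16)
c = torch.empty((16, 16), device='cuda', dtype=torch.float32)
tiny_dot_kernel[(1,)](a, b, c)  # JIT compilation of tl.dot raises the new error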