[FRONTEND] Better dot error message (#531)
This commit is contained in:
@@ -910,7 +910,7 @@ def test_arange(start, device='cuda'):
|
||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
M = 32
|
||||
N = 32
|
||||
K = 8
|
||||
K = 16
|
||||
|
||||
in1 = torch.rand((M, K), dtype=dtype, device=device)
|
||||
in2 = torch.rand((K, N), dtype=dtype, device=device)
|
||||
|
@@ -342,6 +342,26 @@ class constexpr:
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value >= other
|
||||
|
||||
def __gt__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value > other
|
||||
|
||||
def __le__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value <= other
|
||||
|
||||
def __lt__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value < other
|
||||
|
||||
def __eq__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value == other
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.value(*args, **kwds)
|
||||
|
||||
|
@@ -905,6 +905,10 @@ def dot(lhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[-1] == rhs.shape[0]
|
||||
assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
|
||||
"small blocks not supported!"
|
||||
if lhs.type.scalar.is_int():
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
|
Reference in New Issue
Block a user