[FRONTEND] Added a few assertions in semantic.dot
(#977)
This commit is contained in:
@@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
for shape in [
|
||||
[64, 128, 128],
|
||||
[128, 128, 128],
|
||||
[16, 8, 32],
|
||||
[16, 16, 32],
|
||||
[32, 16, 64],
|
||||
[32, 16, 64],
|
||||
]
|
||||
|
@@ -983,6 +983,11 @@ def dot(lhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[1].value == rhs.shape[0].value
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16,\
|
||||
"small blocks not supported!"
|
||||
if lhs.type.scalar.is_int():
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
|
Reference in New Issue
Block a user