[FRONTEND] Added a few assertions in semantic.dot (#977)

This commit is contained in:
Philippe Tillet
2022-12-12 00:07:14 -08:00
committed by GitHub
parent e552219104
commit e5cfa0f633
2 changed files with 6 additions and 1 deletions

View File

@@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
for shape in [
[64, 128, 128],
[128, 128, 128],
[16, 8, 32],
[16, 16, 32],
[32, 16, 64],
[32, 16, 64],
]

View File

@@ -983,6 +983,11 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[1].value == rhs.shape[0].value
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
if lhs.type.scalar.is_int():
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32