[FRONTEND] Added a few assertions in semantic.dot
(#977)
This commit is contained in:
@@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
|||||||
for shape in [
|
for shape in [
|
||||||
[64, 128, 128],
|
[64, 128, 128],
|
||||||
[128, 128, 128],
|
[128, 128, 128],
|
||||||
[16, 8, 32],
|
[16, 16, 32],
|
||||||
[32, 16, 64],
|
[32, 16, 64],
|
||||||
[32, 16, 64],
|
[32, 16, 64],
|
||||||
]
|
]
|
||||||
|
@@ -983,6 +983,11 @@ def dot(lhs: tl.tensor,
|
|||||||
allow_tf32: bool,
|
allow_tf32: bool,
|
||||||
builder: ir.builder) -> tl.tensor:
|
builder: ir.builder) -> tl.tensor:
|
||||||
assert lhs.type.is_block() and rhs.type.is_block()
|
assert lhs.type.is_block() and rhs.type.is_block()
|
||||||
|
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||||
|
assert lhs.shape[1].value == rhs.shape[0].value
|
||||||
|
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||||
|
and rhs.shape[1].value >= 16,\
|
||||||
|
"small blocks not supported!"
|
||||||
if lhs.type.scalar.is_int():
|
if lhs.type.scalar.is_int():
|
||||||
_0 = builder.get_int32(0)
|
_0 = builder.get_int32(0)
|
||||||
ret_scalar_ty = tl.int32
|
ret_scalar_ty = tl.int32
|
||||||
|
Reference in New Issue
Block a user