From e5cfa0f63399555e690018e0171276287502d575 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 12 Dec 2022 00:07:14 -0800 Subject: [PATCH] [FRONTEND] Added a few assertions in `semantic.dot` (#977) --- python/tests/test_gemm.py | 2 +- python/triton/language/semantic.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 7d502328e..40e91a588 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -72,7 +72,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B): for shape in [ [64, 128, 128], [128, 128, 128], - [16, 8, 32], + [16, 16, 32], [32, 16, 64], [32, 16, 64], ] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2c0e8a475..95db5387f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -983,6 +983,11 @@ def dot(lhs: tl.tensor, allow_tf32: bool, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() + assert len(lhs.shape) == 2 and len(rhs.shape) == 2 + assert lhs.shape[1].value == rhs.shape[0].value + assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \ + and rhs.shape[1].value >= 16,\ + "small blocks not supported!" if lhs.type.scalar.is_int(): _0 = builder.get_int32(0) ret_scalar_ty = tl.int32