[Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)

This PR

1. Fix wpt overflow issue in mma v2
2. Refine transpose logic
This commit is contained in:
Yan Chunwei
2022-11-23 11:27:15 +08:00
committed by GitHub
parent 07786dc932
commit 037f9efa95
2 changed files with 64 additions and 37 deletions

View File

@@ -27,8 +27,6 @@ def matmul_no_scf_kernel(
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, c)
# TODO: num_warps could only be 4 for now
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
@@ -172,6 +170,7 @@ def get_proper_err(a, b, golden):
# Non-forloop
[64, 32, 64, 4, 64, 32, 64, False, False],
[128, 64, 128, 4, 128, 64, 128, False, False],
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
# K-Forloop
[64, 32, 128, 4, 64, 32, 64, False, False],
[128, 16, 128, 4, 128, 16, 32, False, False],
@@ -186,6 +185,7 @@ def get_proper_err(a, b, golden):
[128, 256, 128, 4, 128, 256, 32, False, False],
[256, 128, 64, 4, 256, 128, 16, False, False],
[128, 64, 128, 4, 128, 64, 32, False, False],
# [16, 16, 64, 4, 16, 16, 16, False, False], # TODO failed due to pipeline pass
# trans
[128, 64, 128, 4, 128, 64, 32, True, False],
[128, 64, 128, 4, 128, 64, 32, False, True],