[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)

This commit is contained in:
Yan Chunwei
2022-12-07 16:58:38 +08:00
committed by GitHub
parent b2b793dfb5
commit 4eab9dcedf
6 changed files with 58 additions and 24 deletions

View File

@@ -295,18 +295,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[16, 16, 16],
[16, 16, 32],
[32, 16, 16],
[32, 32, 32],
[128, 16, 16],
]
for num_warps in [1]
for trans_a in [False]
for trans_b in [False]
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
[16, 16, 16, 1, 16, 16, 16, False, False],
[16, 16, 32, 1, 16, 16, 32, False, False],
[32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)