[Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1 (#972)

This commit is contained in:
goostavz
2022-12-09 18:36:05 +08:00
committed by GitHub
parent 3ed36dcb4d
commit 793012b4c4
2 changed files with 32 additions and 2 deletions

View File

@@ -302,10 +302,36 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
[32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
# split-K
# # split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
# numWarps > 1
[32, 32, 64, 2, 32, 32, 32, False, False],
[64, 32, 64, 4, 64, 32, 64, False, False],
[128, 64, 128, 4, 128, 64, 128, False, False],
# [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue, hang on Volta
# K-Forloop
# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
# [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k, hang on Volta
[64, 32, 128, 4, 64, 32, 64, False, False],
[128, 16, 128, 4, 128, 16, 32, False, False],
# [32, 16, 128, 4, 32, 16, 32, False, False], # hang on Volta
[32, 64, 128, 4, 32, 64, 32, False, False],
[32, 128, 256, 4, 32, 128, 64, False, False],
[64, 128, 64, 4, 64, 128, 32, False, False],
[64, 64, 128, 4, 64, 64, 32, False, False],
[128, 128, 64, 4, 128, 128, 32, False, False],
[128, 128, 128, 4, 128, 128, 32, False, False],
[128, 128, 256, 4, 128, 128, 64, False, False],
[128, 256, 128, 4, 128, 256, 32, False, False],
[256, 128, 64, 4, 256, 128, 16, False, False],
[128, 64, 128, 4, 128, 64, 32, False, False],
# [16, 16, 64, 4, 16, 16, 16, False, False], # hang on Volta
[32, 32, 64, 4, 32, 32, 32, False, False],
# trans
# [128, 64, 128, 4, 128, 64, 32, True, False],
# [128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
    """Forward a single GEMM configuration to ``test_gemm``.

    The nine arguments (problem sizes, warp count, block sizes and the two
    transpose flags) are handed on positionally, unchanged and in the same
    order in which they are received. Nothing is returned; any checking is
    done by ``test_gemm`` itself.
    """
    # Pack the configuration once so the forwarding call stays short.
    gemm_config = (
        SIZE_M, SIZE_N, SIZE_K,
        NUM_WARPS,
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
        TRANS_A, TRANS_B,
    )
    test_gemm(*gemm_config)