adapt isMMAv1Row in backend (#969)

This commit is contained in:
Yan Chunwei
2022-12-09 12:52:43 +08:00
committed by GitHub
parent 0d27912554
commit 13644e7ac4
3 changed files with 28 additions and 12 deletions

View File

@@ -303,9 +303,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, True, False],
[128, 32, 32, 1, 128, 32, 32, True, True],
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
[16, 16, 32, 1, 16, 16, 16, True, False],
[16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 128, 1, 64, 64, 32, True, True],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)