testing things...
This commit is contained in:
@@ -297,22 +297,22 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
# NOTE this is useful only on Volta GPU.
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
||||
# Non-forloop
|
||||
[16, 16, 16, 1, 16, 16, 16, False, False],
|
||||
[16, 16, 32, 1, 16, 16, 32, False, False],
|
||||
[32, 16, 32, 1, 32, 16, 32, False, False],
|
||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
||||
# [16, 16, 16, 1, 16, 16, 16, False, False],
|
||||
# [16, 16, 32, 1, 16, 16, 32, False, False],
|
||||
# [32, 16, 32, 1, 32, 16, 32, False, False],
|
||||
# [32, 32, 32, 1, 32, 32, 32, False, False],
|
||||
# [128, 32, 32, 1, 128, 32, 32, False, False],
|
||||
|
||||
[128, 32, 32, 1, 128, 32, 32, True, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, True, True],
|
||||
# [128, 32, 32, 1, 128, 32, 32, True, False],
|
||||
# [128, 32, 32, 1, 128, 32, 32, True, True],
|
||||
|
||||
# split-K
|
||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
||||
# # split-K
|
||||
# [16, 16, 32, 1, 16, 16, 16, False, False],
|
||||
# [64, 64, 128, 1, 64, 64, 32, False, False],
|
||||
|
||||
[16, 16, 32, 1, 16, 16, 16, True, False],
|
||||
[16, 16, 32, 1, 16, 16, 16, True, True],
|
||||
[64, 64, 128, 1, 64, 64, 32, True, True],
|
||||
# [16, 16, 32, 1, 16, 16, 16, True, False],
|
||||
# [16, 16, 32, 1, 16, 16, 16, True, True],
|
||||
[64, 64, 64, 1, 64, 64, 32, True, False],
|
||||
])
|
||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||
|
Reference in New Issue
Block a user