[Triton-MLIR][CI] Fix v100 tests to avoid skipping tests mistakenly (#975)

commit 4fb048873a (parent be2f70699c)
Author: Yan Chunwei
Date: 2022-12-11 12:57:51 +08:00
Committed by: GitHub


@@ -43,8 +43,7 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
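
The same one-line change repeats at every guarded GEMM test in this file: the explicit valid_on_Volta() check followed by pytest.skip is collapsed into a single guard_for_volta() call that performs the skip itself. A minimal sketch of the call-site pattern, with the test body elided and parameter names taken from the hunk above:

# Old pattern: skip whenever the predicate returns False. Because
# valid_on_Volta() also returned False on non-Volta GPUs, tests could be
# skipped on hardware that actually supports them.
def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
        pytest.skip("Not valid on Volta")
    ...

# New pattern: the guard skips internally, and only when the test is
# actually running on a Volta GPU with an unsupported feature combination.
def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
    ...
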
@@ -84,8 +83,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True)
     SIZE_M, SIZE_N, SIZE_K = SHAPE
@@ -201,8 +199,7 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -279,8 +276,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)
 
-    if not valid_on_Volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32)
 
     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -306,16 +302,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
 
 
-def valid_on_Volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
     '''
     Tell whether the test case is valid on Volta GPU.
     Some features are WIP, so the corresponding support are missing.
     '''
-    if is_int8 or is_tf32:
-        return False
     capability = torch.cuda.get_device_capability()
     is_on_Volta = capability[0] < 8
     # TODO[Superjomn]: Remove the constraints below when features are ready
+    is_feature_supported = not (is_int8 or is_tf32)
     is_feature_ready = not (trans_a or trans_b)
-    return is_on_Volta and is_feature_ready
+
+    if is_on_Volta:
+        if (not is_feature_supported) or (not is_feature_ready):
+            pytest.skip("Not valid on Volta")
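
Taken together, the hunks replace the boolean helper valid_on_Volta with guard_for_volta. The old helper returned False both when a feature is not yet supported (int8, tf32, a transposed operand) and whenever the GPU was not Volta at all, and every call site skipped on False, so tests could be skipped mistakenly on newer GPUs that run them fine. The new helper skips only when it is actually running on Volta (compute capability below 8) and the feature combination is unsupported or not yet ready there. Below is a self-contained sketch of the new guard, reconstructed from the hunk above with explanatory comments added; it is a reading aid, not a verbatim copy of the file:

import pytest
import torch


def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
    '''
    Skip the calling test if it cannot run on the current (Volta) GPU.

    num_warps is accepted but not used in the body shown in this commit.
    '''
    # Volta (V100) is sm_70; any major capability below 8 is pre-Ampere.
    capability = torch.cuda.get_device_capability()
    is_on_Volta = capability[0] < 8

    # TODO[Superjomn]: Remove the constraints below when features are ready
    is_feature_supported = not (is_int8 or is_tf32)  # int8 / tf32 GEMM still WIP on Volta
    is_feature_ready = not (trans_a or trans_b)      # transposed operands still WIP on Volta

    if is_on_Volta:
        if (not is_feature_supported) or (not is_feature_ready):
            # pytest.skip() raises, so the calling test ends here.
            pytest.skip("Not valid on Volta")

On an Ampere or newer GPU (capability major version >= 8) the guard is a no-op and the test runs; previously `not valid_on_Volta(...)` evaluated to True there and the test was skipped.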