[Triton-MLIR][CI] Fix v100 tests to avoid skipping tests mistakenly (#975)
@@ -43,8 +43,7 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
 
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
@@ -84,8 +83,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True)
 
     SIZE_M, SIZE_N, SIZE_K = SHAPE
 
@@ -201,8 +199,7 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
 
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -279,8 +276,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)
 
-    if not valid_on_Volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32):
-        pytest.skip("Not valid on Volta")
+    guard_for_volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32)
 
     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -306,16 +302,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
 
 
-def valid_on_Volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
     '''
     Tell whether the test case is valid on Volta GPU.
     Some features are WIP, so the corresponding support are missing.
     '''
-    if is_int8 or is_tf32:
-        return False
 
     capability = torch.cuda.get_device_capability()
     is_on_Volta = capability[0] < 8
     # TODO[Superjomn]: Remove the constraints below when features are ready
+    is_feature_supported = not (is_int8 or is_tf32)
     is_feature_ready = not (trans_a or trans_b)
-    return is_on_Volta and is_feature_ready
+    if is_on_Volta:
+        if (not is_feature_supported) or (not is_feature_ready):
+            pytest.skip("Not valid on Volta")
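
Note on the fix: the old helper returned is_on_Volta and is_feature_ready, so on any non-Volta GPU (compute capability >= 8) it returned False and every caller skipped the test; int8 and tf32 cases were likewise skipped on all hardware. The new guard_for_volta calls pytest.skip only when the device really is Volta and the requested feature combination is still WIP there. A minimal, self-contained sketch of the intended usage follows; the test name, warp count, and parameters are hypothetical, and only the guard logic comes from this diff (the two nested ifs are condensed into one equivalent condition):

import pytest
import torch


def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
    # Sketch of the helper introduced in this commit: skip only on Volta,
    # and only for feature combinations that are not yet supported there.
    capability = torch.cuda.get_device_capability()
    is_on_Volta = capability[0] < 8
    is_feature_supported = not (is_int8 or is_tf32)
    is_feature_ready = not (trans_a or trans_b)
    if is_on_Volta and not (is_feature_supported and is_feature_ready):
        pytest.skip("Not valid on Volta")


@pytest.mark.parametrize('trans_a,trans_b', [(False, False), (False, True)])
def test_guard_example(trans_a, trans_b):
    # Hypothetical test body: on capability >= 8 GPUs nothing is skipped;
    # on a V100 (capability 7.0) only the transposed variant is skipped.
    guard_for_volta(num_warps=4, trans_a=trans_a, trans_b=trans_b)
    # ... the real GEMM correctness check would run here ...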