[BACKEND] Refine V100 tests and fix mmav1 numWarps > 1 hang issue (#971)
This PR:
- Fixes the numWarps > 1 hang in the mmav1 (Volta) path.
- Adds the existing test cases in test_gemm.py to CI and introduces a common helper `valid_on_Volta` that decides whether a test case should run on Volta or be skipped. The column-major (transposed) cases are currently disabled.
- Adds test_core.py and other tests to the Volta CI; `test_printf.py` still fails, so it is not included.
.github/workflows/integration-tests.yml
@@ -89,7 +89,14 @@ jobs:
         if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
         run: |
           cd python/tests
-          pytest test_gemm.py::test_gemm_for_mmav1
+          pytest -k "not test_where_broadcast and not test_dot" test_core.py
+          pytest test_gemm.py
+          pytest test_backend.py
+          pytest test_reduce.py
+          pytest test_vecadd.py
+          pytest test_elementwise.py
+          pytest test_ext_elemwise.py
+          pytest test_transpose.py

       - name: Run CXX unittests
         run: |
@@ -617,13 +617,15 @@ SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
   bool changed = false;
   do {
     changed = false;
+    int pre = ret[0];
     if (ret[0] * ret[1] < numWarps) {
       ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
-      changed = true;
+      changed = pre != ret[0];
     }
     if (ret[0] * ret[1] < numWarps) {
+      pre = ret[1];
       ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
-      changed = true;
+      changed = pre != ret[1];
     }
   } while (changed);
   return ret;
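Why the old loop hung: `warpsPerTileV1` keeps doubling `ret[0]` and `ret[1]` until their product reaches `numWarps`, but `std::clamp` caps each dimension at `shape[i] / shapePerWarp[i]`. Once both dimensions are saturated the product can never reach `numWarps`, yet the old code still set `changed = true` on every pass, so the `do/while` never exited (for example a 16x16 block with 16 warps, the kind of case previously flagged "wpt overflow issue, hang on Volta" in test_gemm.py). The fix only reports a change when `ret` actually moved. Below is a minimal Python sketch of that control flow, with illustrative names and values; it is not the production C++ code.

```python
# Minimal sketch of the fixed warpsPerTileV1 control flow (illustrative only).

def warps_per_tile_v1(shape, shape_per_warp, num_warps):
    ret = [1, 1]
    changed = True
    while changed:  # mirrors the C++ do/while
        changed = False
        pre = ret[0]
        if ret[0] * ret[1] < num_warps:
            # clamp doubling to the number of per-warp tiles along dim 0
            ret[0] = max(1, min(ret[0] * 2, shape[0] // shape_per_warp[0]))
            # old code: changed = True (unconditional -> could loop forever)
            changed = pre != ret[0]
        if ret[0] * ret[1] < num_warps:
            pre = ret[1]
            ret[1] = max(1, min(ret[1] * 2, shape[1] // shape_per_warp[1]))
            changed = pre != ret[1]
    return ret


# A saturating case: a 16x16 tile, a 16x16 per-warp shape, and 16 warps.
# Both dimensions clamp to 1, so ret never grows; with the old unconditional
# `changed = True` the loop spins forever, with the fix it exits at [1, 1].
print(warps_per_tile_v1((16, 16), (16, 16), 16))  # -> [1, 1]
```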
@@ -43,6 +43,9 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
+        pytest.skip("Not valid on Volta")
+
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -81,6 +84,9 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True):
+        pytest.skip("Not valid on Volta")
+
     SIZE_M, SIZE_N, SIZE_K = SHAPE

     if (TRANS_A):
@@ -195,6 +201,9 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
+        pytest.skip("Not valid on Volta")
+
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
     else:
@@ -270,6 +279,9 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)

+    if not valid_on_Volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32):
+        pytest.skip("Not valid on Volta")
+
     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32

@@ -294,44 +306,16 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


-# NOTE this is useful only on Volta GPU.
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
-    # Non-forloop
-    [16, 16, 16, 1, 16, 16, 16, False, False],
-    [16, 16, 32, 1, 16, 16, 32, False, False],
-    [32, 16, 32, 1, 32, 16, 32, False, False],
-    [32, 32, 32, 1, 32, 32, 32, False, False],
-    [128, 32, 32, 1, 128, 32, 32, False, False],
-    # # split-K
-    [16, 16, 32, 1, 16, 16, 16, False, False],
-    [64, 64, 128, 1, 64, 64, 32, False, False],
-    # numWarps > 1
-    [32, 32, 64, 2, 32, 32, 32, False, False],
-    [64, 32, 64, 4, 64, 32, 64, False, False],
-    [128, 64, 128, 4, 128, 64, 128, False, False],
-    # [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue, hang on Volta
-    # K-Forloop
-    # [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
-    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
-    # [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k, hang on Volta
-    [64, 32, 128, 4, 64, 32, 64, False, False],
-    [128, 16, 128, 4, 128, 16, 32, False, False],
-    # [32, 16, 128, 4, 32, 16, 32, False, False], # hang on Volta
-    [32, 64, 128, 4, 32, 64, 32, False, False],
-    [32, 128, 256, 4, 32, 128, 64, False, False],
-    [64, 128, 64, 4, 64, 128, 32, False, False],
-    [64, 64, 128, 4, 64, 64, 32, False, False],
-    [128, 128, 64, 4, 128, 128, 32, False, False],
-    [128, 128, 128, 4, 128, 128, 32, False, False],
-    [128, 128, 256, 4, 128, 128, 64, False, False],
-    [128, 256, 128, 4, 128, 256, 32, False, False],
-    [256, 128, 64, 4, 256, 128, 16, False, False],
-    [128, 64, 128, 4, 128, 64, 32, False, False],
-    # [16, 16, 64, 4, 16, 16, 16, False, False], # hang on Volta
-    [32, 32, 64, 4, 32, 32, 32, False, False],
-    # trans
-    # [128, 64, 128, 4, 128, 64, 32, True, False],
-    # [128, 64, 128, 4, 128, 64, 32, False, True],
-])
-def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
+def valid_on_Volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+    '''
+    Tell whether the test case is valid on a Volta GPU.
+    Some features are WIP, so the corresponding support is missing.
+    '''
+    if is_int8 or is_tf32:
+        return False
+
+    capability = torch.cuda.get_device_capability()
+    is_on_Volta = capability[0] < 8
+    # TODO[Superjomn]: Remove the constraints below when features are ready
+    is_feature_ready = not (trans_a or trans_b)
+    return is_on_Volta and is_feature_ready
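A usage note on the new helper: `valid_on_Volta` returns True only on a pre-Ampere device (`torch.cuda.get_device_capability()[0] < 8`, e.g. capability 7.0 on a V100) and only for fp16, non-transposed inputs; int8, tf32, and column-major (transposed) cases return False, which makes the calling test skip itself. A small illustration follows, assuming a CUDA-capable machine; the device names and printed values are examples, not output from the CI run.

```python
# Illustrative only: how the capability check in valid_on_Volta behaves.
import torch

cap = torch.cuda.get_device_capability()  # e.g. (7, 0) on a V100, (8, 0) on an A100
print(cap, cap[0] < 8)                    # on a V100: (7, 0) True

# Expected results on a V100, per the function body above:
#   valid_on_Volta(4, trans_a=False, trans_b=False)                -> True   (test runs)
#   valid_on_Volta(4, trans_a=True,  trans_b=False)                -> False  (test skips)
#   valid_on_Volta(4, trans_a=False, trans_b=False, is_int8=True)  -> False  (test skips)
#   valid_on_Volta(4, trans_a=False, trans_b=False, is_tf32=True)  -> False  (test skips)
```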