From 24fd953f9a3d84b69d047b72f4069ded00daee6d Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Fri, 9 Dec 2022 23:41:22 +0800
Subject: [PATCH] [BACKEND] Refine v100 tests and fix mmav1 numwarps>1 hang issue (#971)

This PR:
- Fixes the numWarps>1 hang issue.
- Adds the existing test cases in test_gemm.py to CI, and adds a common helper
  `valid_on_Volta` to determine whether a test case should run on Volta or be
  skipped.
  - Currently, the column-major cases are disabled.
- Adds test_core.py and other tests to the Volta CI.
  - `test_printf.py` currently fails.
---
 .github/workflows/integration-tests.yml      |  9 ++-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp |  6 +-
 python/tests/test_gemm.py                    | 66 ++++++++------------
 3 files changed, 37 insertions(+), 44 deletions(-)

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index f3110988a..1d7a48bf0 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -89,7 +89,14 @@ jobs:
       if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
       run: |
         cd python/tests
-        pytest test_gemm.py::test_gemm_for_mmav1
+        pytest -k "not test_where_broadcast and not test_dot" test_core.py
+        pytest test_gemm.py
+        pytest test_backend.py
+        pytest test_reduce.py
+        pytest test_vecadd.py
+        pytest test_elementwise.py
+        pytest test_ext_elemwise.py
+        pytest test_transpose.py
 
     - name: Run CXX unittests
       run: |
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index e14bae003..5a4d19a46 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -617,13 +617,15 @@ SmallVector warpsPerTileV1(triton::DotOp dotOp,
   bool changed = false;
   do {
     changed = false;
+    int pre = ret[0];
     if (ret[0] * ret[1] < numWarps) {
       ret[0] = std::clamp(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
-      changed = true;
+      changed = pre != ret[0];
     }
     if (ret[0] * ret[1] < numWarps) {
+      pre = ret[1];
       ret[1] = std::clamp(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
-      changed = true;
+      changed = pre != ret[1];
     }
   } while (changed);
   return ret;
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index cd3a7f805..b2ce97cb2 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -43,6 +43,9 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
+        pytest.skip("Not valid on Volta")
+
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -81,6 +84,9 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True):
+        pytest.skip("Not valid on Volta")
+
     SIZE_M, SIZE_N, SIZE_K = SHAPE
 
     if (TRANS_A):
@@ -195,6 +201,9 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    if not valid_on_Volta(NUM_WARPS, TRANS_A, TRANS_B):
+        pytest.skip("Not valid on Volta")
+
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
     else:
@@ -270,6 +279,9 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)
 
+    if not valid_on_Volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32):
+        pytest.skip("Not valid on Volta")
+
     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
 
@@ -294,44 +306,16 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
 
 
-# NOTE this is useful only on Volta GPU.
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
-    # Non-forloop
-    [16, 16, 16, 1, 16, 16, 16, False, False],
-    [16, 16, 32, 1, 16, 16, 32, False, False],
-    [32, 16, 32, 1, 32, 16, 32, False, False],
-    [32, 32, 32, 1, 32, 32, 32, False, False],
-    [128, 32, 32, 1, 128, 32, 32, False, False],
-    # # split-K
-    [16, 16, 32, 1, 16, 16, 16, False, False],
-    [64, 64, 128, 1, 64, 64, 32, False, False],
-    # numWarps > 1
-    [32, 32, 64, 2, 32, 32, 32, False, False],
-    [64, 32, 64, 4, 64, 32, 64, False, False],
-    [128, 64, 128, 4, 128, 64, 128, False, False],
-    # [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue, hang on Volta
-    # K-Forloop
-    # [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
-    [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
-    # [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k, hang on Volta
-    [64, 32, 128, 4, 64, 32, 64, False, False],
-    [128, 16, 128, 4, 128, 16, 32, False, False],
-    # [32, 16, 128, 4, 32, 16, 32, False, False], # hang on Volta
-    [32, 64, 128, 4, 32, 64, 32, False, False],
-    [32, 128, 256, 4, 32, 128, 64, False, False],
-    [64, 128, 64, 4, 64, 128, 32, False, False],
-    [64, 64, 128, 4, 64, 64, 32, False, False],
-    [128, 128, 64, 4, 128, 128, 32, False, False],
-    [128, 128, 128, 4, 128, 128, 32, False, False],
-    [128, 128, 256, 4, 128, 128, 64, False, False],
-    [128, 256, 128, 4, 128, 256, 32, False, False],
-    [256, 128, 64, 4, 256, 128, 16, False, False],
-    [128, 64, 128, 4, 128, 64, 32, False, False],
-    # [16, 16, 64, 4, 16, 16, 16, False, False], # hang on Volta
-    [32, 32, 64, 4, 32, 32, 32, False, False],
-    # trans
-    # [128, 64, 128, 4, 128, 64, 32, True, False],
-    # [128, 64, 128, 4, 128, 64, 32, False, True],
-])
-def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
+def valid_on_Volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+    '''
+    Tells whether the test case is valid on a Volta GPU.
+    Some features are still WIP, so the corresponding support is missing.
+    '''
+    if is_int8 or is_tf32:
+        return False
+
+    capability = torch.cuda.get_device_capability()
+    is_on_Volta = capability[0] < 8
+    # TODO[Superjomn]: Remove the constraints below when features are ready
+    is_feature_ready = not (trans_a or trans_b)
+    return is_on_Volta and is_feature_ready
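
Note on the hang fix in Combine.cpp: the old warpsPerTileV1 loop set `changed = true` unconditionally whenever `ret[0] * ret[1] < numWarps`, so once `std::clamp` saturated at `shape[i] / shapePerWarp[i]` (i.e. the tile could not host that many warps) `ret` stopped growing but the loop never exited. The following is a minimal, standalone sketch of the fixed termination logic; the function name `warpsPerTileSketch` and the example shape/warp values are illustrative assumptions, not the actual Triton pass or its inputs.

// Sketch (C++17) of the termination fix: only keep iterating while a doubling
// actually takes effect, mirroring the `pre != ret[i]` checks in the patch.
#include <algorithm>
#include <array>
#include <cstdio>

std::array<int, 2> warpsPerTileSketch(std::array<int, 2> shape,
                                      std::array<int, 2> shapePerWarp,
                                      int numWarps) {
  std::array<int, 2> ret = {1, 1};
  bool changed = false;
  do {
    changed = false;
    int pre = ret[0];
    if (ret[0] * ret[1] < numWarps) {
      ret[0] = std::clamp(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
      changed = pre != ret[0]; // loop again only if ret[0] actually grew
    }
    if (ret[0] * ret[1] < numWarps) {
      pre = ret[1];
      ret[1] = std::clamp(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
      changed = pre != ret[1]; // same check for ret[1]
    }
  } while (changed);
  return ret;
}

int main() {
  // Hypothetical example: a 32x32 tile with 16x16 per warp can host at most
  // 2x2 = 4 warps. With numWarps = 8 the clamp saturates at {2, 2}; the old
  // version kept changed = true and spun forever, the fixed loop exits.
  auto ret = warpsPerTileSketch({32, 32}, {16, 16}, /*numWarps=*/8);
  std::printf("warps per tile: %d x %d\n", ret[0], ret[1]);
  return 0;
}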