From dfad6bdf368e665afd3ae4c020635b7aad43d98f Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Tue, 1 Nov 2022 15:00:12 +0000
Subject: [PATCH] reduce the skips for test_reduce functions

---
 python/test/unit/language/test_core.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 00241403c..cad10c871 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -892,7 +892,8 @@ def test_f16_to_f8_rounding():
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")
 
     # triton kernel
     @triton.jit
@@ -953,7 +954,8 @@ reduce_configs2 = [
 def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
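
For context, a minimal standalone sketch of the dtype-conditional skip pattern
this patch adopts, assuming a pytest + PyTorch environment. The names
test_example_reduce and SKIPPED_DTYPES are hypothetical illustrations and are
not part of the patch itself:

import pytest
import torch

# Hypothetical constant mirroring the dtypes the patch still skips on ROCm.
SKIPPED_DTYPES = ["int8", "int16", "uint8", "uint16"]

@pytest.mark.parametrize("dtype_str", ["int8", "int32", "float16", "float32"])
def test_example_reduce(dtype_str):
    # torch.version.hip is a version string on ROCm builds and None elsewhere,
    # so this branch triggers only on ROCm, and only for the listed dtypes.
    if torch.version.hip is not None and dtype_str in SKIPPED_DTYPES:
        pytest.skip(f"test_example_reduce[{dtype_str}] skipped on ROCM")
    # Every other dtype/platform combination runs the real test body.
    x = torch.arange(8, dtype=torch.float32)
    assert x.sum().item() == 28.0

Compared with the unconditional skip it replaces, this pattern keeps ROCm
coverage for every dtype that does not hit the known segfault.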