diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 00241403c..cad10c871 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -892,7 +892,8 @@ def test_f16_to_f8_rounding():
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")

     # triton kernel
     @triton.jit
@@ -953,7 +954,8 @@ reduce_configs2 = [
 def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
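
Note (not part of the diff): both hunks repeat the same dtype gate, narrowing the blanket ROCm skip to only the small integer dtypes. A minimal sketch of how that gate could be factored into a shared helper; the helper name skip_if_rocm_unsupported_dtype and its placement are assumptions for illustration, not existing Triton test infrastructure:

    # Hypothetical helper -- name and dtype set taken from the diff above,
    # not from any existing Triton API.
    import pytest
    import torch

    # dtypes the diff skips for reductions on ROCm
    _ROCM_UNSUPPORTED_REDUCE_DTYPES = {"int8", "int16", "uint8", "uint16"}

    def skip_if_rocm_unsupported_dtype(test_name, dtype_str):
        # torch.version.hip is non-None only on ROCm builds of PyTorch
        if torch.version.hip is not None and dtype_str in _ROCM_UNSUPPORTED_REDUCE_DTYPES:
            pytest.skip(f"{test_name}[{dtype_str}] skipped on ROCM")

Each test would then call skip_if_rocm_unsupported_dtype("test_reduce1d", dtype_str) (and likewise for test_reduce2d) instead of repeating the inline check, so a future dtype change only needs to touch one place.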