Merge pull request #19 from ROCmSoftwarePlatform/unskip_test_reduce
reduce the skips for test_reduce functions
@@ -892,7 +892,8 @@ def test_f16_to_f8_rounding():
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
     check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")

     # triton kernel
     @triton.jit
@@ -953,7 +954,8 @@ reduce_configs2 = [
 def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
     if torch.version.hip is not None:
-        pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
+        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
+            pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
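Both hunks make the same change: instead of skipping test_reduce1d and test_reduce2d entirely on ROCm, the tests now skip only the 8/16-bit integer dtypes. A minimal standalone sketch of that pattern is below; the helper name _skip_narrow_ints_on_rocm and the factored-out layout are illustrative, not part of the actual test file.

# Minimal sketch of the conditional skip introduced by this PR
# (helper name and standalone layout are assumptions, not the PR's code).
import pytest
import torch


def _skip_narrow_ints_on_rocm(test_name: str, dtype_str: str) -> None:
    # torch.version.hip is a version string on ROCm builds and None otherwise,
    # so this check is a no-op on CUDA builds.
    if torch.version.hip is not None:
        # Only the narrow integer dtypes remain skipped on ROCm;
        # every other dtype now runs.
        if dtype_str in ["int8", "int16", "uint8", "uint16"]:
            pytest.skip(f"{test_name}[{dtype_str}] skipped on ROCM")

With the blanket skip removed, the remaining dtype parametrizations of test_reduce1d and test_reduce2d execute on ROCm instead of being reported as skipped.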