unskip most bfloat tests
@@ -104,13 +104,9 @@ def check_type_supported(dtype):
     '''
     skip test if dtype is not supported on the current device
     '''
-    if torch.version.hip is not None:
-        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
-            pytest.skip("bfloat16 is not supported on AMDGPU")
-    else:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
-            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 
 
 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
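After this hunk, check_type_supported applies the same compute-capability gate on every backend instead of unconditionally skipping bfloat16 under HIP. A minimal sketch of how a test is expected to use it, assuming it lives in the same module as check_type_supported and the test file's usual imports; the test name test_bfloat16_copy is hypothetical, not part of the commit:

    import pytest
    import torch
    import triton
    import triton.language as tl

    def test_bfloat16_copy(device='cuda'):
        # check_type_supported calls pytest.skip internally, so on a
        # device with cc < 80 the body below never runs.
        check_type_supported(torch.bfloat16)

        @triton.jit
        def copy_kernel(X, Z, SIZE: tl.constexpr):
            off = tl.arange(0, SIZE)
            tl.store(Z + off, tl.load(X + off))

        x = torch.randn(128, device=device).to(torch.bfloat16)
        z = torch.empty_like(x)
        copy_kernel[(1,)](x, z, SIZE=128)
        torch.testing.assert_close(z.float(), x.float())

On a pre-Ampere card this reports as a skip with the reason string above rather than as a failure.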
@@ -127,6 +123,9 @@ def test_empty_kernel(dtype_x, device='cuda'):
 
 # generic test functions
 def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
+    if torch.version.hip is not None:
+        if dtype_x == "bfloat16":
+            pytest.skip("unary op with bfloat is not supported on AMDGPU")
     check_type_supported(dtype_x)  # early return if dtype_x is not supported
     SIZE = 128
     # define the kernel / launch-grid
@@ -256,6 +255,10 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
     for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
+    if torch.version.hip is not None:
+        if dtype_x == "bfloat16" and dtype_y == "bfloat16":
+            pytest.skip("binary op with bfloat is not supported on AMDGPU")
+
     expr = f' x {op} y'
     if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
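Both added guards key off torch.version.hip, which is None in CUDA builds of PyTorch and a version string in ROCm builds. A sketch of the same check factored into one helper; the name skip_bfloat_on_hip is hypothetical and not part of this commit:

    import pytest
    import torch

    def skip_bfloat_on_hip(*dtype_names):
        # Hypothetical helper: torch.version.hip is None on CUDA builds
        # and a version string on ROCm builds, so this only fires on AMDGPU.
        if torch.version.hip is not None and all(d == "bfloat16" for d in dtype_names):
            pytest.skip("bfloat ops are not supported on AMDGPU")

With that helper, _test_unary would call skip_bfloat_on_hip(dtype_x) and test_bin_op would call skip_bfloat_on_hip(dtype_x, dtype_y), matching the two guards above.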
@@ -26,12 +26,14 @@ rm -rf /tmp/triton
 # python python/test/test_empty.py
 # -ex 'ignore 1 472' \
 
-pytest --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
+pytest -rs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
 # pytest --verbose python/test/unit/language/test_core.py::test_empty_kernel[float32] 2>&1 | tee /dockerx/triton/test_empty_kernel.log
 # pytest --verbose python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+] 2>&1 | tee /dockerx/triton/test_bin_op.log
 # pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
 # pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw[add-float16-all_neg] 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
 # pytest --verbose "python/test/unit/language/test_core.py::test_reduce1d" 2>&1 | tee /dockerx/triton/test_reduce1d.log
 # pytest --verbose "python/test/unit/language/test_core.py::test_cast[float32-float32-False]" 2>&1 | tee /dockerx/triton/test_cast.log
 # pytest --verbose "python/test/unit/language/test_core.py::test_load_cache_modifier" 2>&1 | tee /dockerx/triton/test_vectorization.log
 
+# mismatch
+# pytest --verbose "python/test/unit/language/test_core.py::test_bin_op[int8-float16-%]" 2>&1 | tee /dockerx/triton/test_bin_op.log
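The functional change to the runner script is the new -rs flag: pytest's -r option prints a short post-run summary, and the s selector lists every skipped test together with the reason string passed to pytest.skip, so the remaining AMDGPU bfloat skips stay visible in the captured log. A sketch of pulling those lines back out of the log, assuming the log path used in the script above:

    # Hypothetical post-processing of the captured log: the summary lines
    # that `pytest -rs` emits for skipped tests begin with "SKIPPED".
    with open("/dockerx/triton/test_core.log") as log:
        for line in log:
            if line.startswith("SKIPPED"):
                print(line.rstrip())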