diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index b2c4a0516..3de303c30 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -104,13 +104,9 @@ def check_type_supported(dtype):
     '''
     skip test if dtype is not supported on the current device
     '''
-    if torch.version.hip is not None:
-        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
-            pytest.skip("bfloat16 is not supported on AMDGPU")
-    else:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
-            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 
 
 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@@ -127,6 +123,9 @@ def test_empty_kernel(dtype_x, device='cuda'):
 
 # generic test functions
 def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
+    if torch.version.hip is not None:
+        if dtype_x == "bfloat16":
+            pytest.skip("unary op with bfloat is not supported on AMDGPU")
     check_type_supported(dtype_x)  # early return if dtype_x is not supported
     SIZE = 128
     # define the kernel / launch-grid
@@ -256,6 +255,10 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
     for dtype_y in dtypes_with_bfloat16
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
+    if torch.version.hip is not None:
+        if dtype_x == "bfloat16" and dtype_y == "bfloat16":
+            pytest.skip("binary op with bfloat is not supported on AMDGPU")
+
     expr = f' x {op} y'
     if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
diff --git a/scripts/amd/test.sh b/scripts/amd/test.sh
index d6c1d9ad5..c9398099f 100755
--- a/scripts/amd/test.sh
+++ b/scripts/amd/test.sh
@@ -26,12 +26,14 @@ rm -rf /tmp/triton
 # python python/test/test_empty.py
 
 # -ex 'ignore 1 472' \
-pytest --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
+pytest -rs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
 # pytest --verbose python/test/unit/language/test_core.py::test_empty_kernel[float32] 2>&1 | tee /dockerx/triton/test_empty_kernel.log
 # pytest --verbose python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+] 2>&1 | tee /dockerx/triton/test_bin_op.log
 # pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
 # pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw[add-float16-all_neg] 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
 # pytest --verbose "python/test/unit/language/test_core.py::test_reduce1d" 2>&1 | tee /dockerx/triton/test_reduce1d.log
+# pytest --verbose "python/test/unit/language/test_core.py::test_cast[float32-float32-False]" 2>&1 | tee /dockerx/triton/test_cast.log
+# pytest --verbose "python/test/unit/language/test_core.py::test_load_cache_modifier" 2>&1 | tee /dockerx/triton/test_vectorization.log # mismatch
 # pytest --verbose "python/test/unit/language/test_core.py::test_bin_op[int8-float16-%]"" 2>&1 | tee /dockerx/triton/test_bin_op.log
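
For reference, here is a minimal standalone sketch of how the new ROCm skip guards behave once the patch is applied. The test name `test_binop_skip_sketch`, its parametrization, and its dummy body are hypothetical; only the guard condition is copied from the `test_bin_op` hunk above.

```python
# Illustrative sketch only, not part of the patch.
import pytest
import torch


@pytest.mark.parametrize("dtype_x, dtype_y",
                         [("bfloat16", "bfloat16"), ("float32", "float32")])
def test_binop_skip_sketch(dtype_x, dtype_y):
    # Same guard the patch adds to test_bin_op: skip only when both operands
    # are bfloat16 and this is a HIP (AMDGPU) build of PyTorch.
    if torch.version.hip is not None:
        if dtype_x == "bfloat16" and dtype_y == "bfloat16":
            pytest.skip("binary op with bfloat is not supported on AMDGPU")
    # A real test body would build and run the Triton kernel here.
    assert dtype_x == dtype_y
```

Running the suite with `pytest -rs`, as the updated scripts/amd/test.sh now does, lists these skips and their reasons in the end-of-run summary rather than only showing the bare `s` markers.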