From 9b3f2487b5fb00475bbfb3e03a41c14981fd39e0 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 31 Oct 2022 18:33:47 +0000 Subject: [PATCH] fix minor bug --- python/test/unit/language/test_core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3de303c30..8a695246c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -104,9 +104,12 @@ def check_type_supported(dtype): ''' skip test if dtype is not supported on the current device ''' - cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) - if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): - pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if torch.version.hip is not None: + pass + else: + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])