fix bfloat failure

Michael Melesse
2022-10-26 17:40:28 +00:00
parent 88d57ef9c9
commit 0cae0168ec
3 changed files with 10 additions and 6 deletions

@@ -104,9 +104,13 @@ def check_type_supported(dtype):
     '''
     skip test if dtype is not supported on the current device
     '''
-    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
-        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+    if torch.version.hip is not None:
+        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
+            pytest.skip("bfloat16 is not supported on AMDGPU")
+    else:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
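
For context, a minimal sketch of how check_type_supported reads after the hunk above, plus a hypothetical parametrized test that calls it. Only the function body comes from this commit; the import paths, the dtypes list, and the test itself are illustrative assumptions.

import pytest
import torch
import triton.language as tl
import triton._C.libtriton.triton as _triton  # assumed binding path for the runtime API used above

dtypes = ["int32", "float16", "bfloat16", "float32"]  # illustrative list, not from this commit


def check_type_supported(dtype):
    '''
    skip test if dtype is not supported on the current device
    '''
    if torch.version.hip is not None:
        # ROCm build of PyTorch: skip bfloat16 regardless of the device
        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
            pytest.skip("bfloat16 is not supported on AMDGPU")
    else:
        # CUDA build: bfloat16 requires compute capability >= 80 (Ampere or newer)
        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_unary_op(dtype_x):  # hypothetical test, shown only to illustrate the intended usage
    check_type_supported(dtype_x)  # skips before any kernel work happens on unsupported hardware
    ...

Branching on torch.version.hip (None on CUDA builds, a version string on ROCm builds) keeps the compute-capability query on the NVIDIA path only, where the backend.CUDA runtime call is valid.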