fix minor bug: skip the CUDA compute-capability check in check_type_supported on HIP devices
@@ -104,9 +104,12 @@ def check_type_supported(dtype):
     '''
     skip test if dtype is not supported on the current device
     '''
-    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
-        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+    if torch.version.hip is not None:
+        pass
+    else:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 
 
 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
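
Read as plain code rather than a diff, the patched helper looks like the sketch below. This is a minimal sketch, not part of the commit itself: the import lines (pytest, torch, triton.language as tl, and the _triton native-binding module) are assumptions about how the surrounding test module is set up.

    import pytest
    import torch
    import triton._C.libtriton.triton as _triton  # assumed binding path for the _triton module used in the diff
    import triton.language as tl


    def check_type_supported(dtype):
        '''
        skip test if dtype is not supported on the current device
        '''
        # torch.version.hip is a version string on ROCm builds and None on
        # CUDA builds, so this guard routes HIP devices around the CUDA
        # compute-capability query, which is what this commit fixes.
        if torch.version.hip is not None:
            pass
        else:
            cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
            # bfloat16 requires cc >= 80 (Ampere-class) NVIDIA hardware.
            if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
                pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")

A parametrized test would typically call check_type_supported(dtype_x) at the top of its body, so unsupported dtypes are skipped before any kernel is launched.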