Merge pull request #15 from ROCmSoftwarePlatform/bfloat_enable

unskip most bfloat tests
This commit is contained in:
rsanthanam-amd
2022-10-31 13:10:30 -05:00
committed by GitHub
2 changed files with 13 additions and 8 deletions

View File

@@ -104,13 +104,9 @@ def check_type_supported(dtype):
'''
skip test if dtype is not supported on the current device
'''
if torch.version.hip is not None:
if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
pytest.skip("bfloat16 is not supported on AMDGPU")
else:
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@@ -127,6 +123,9 @@ def test_empty_kernel(dtype_x, device='cuda'):
# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
if torch.version.hip is not None:
if dtype_x == "bfloat16":
pytest.skip("unary op with bfloat is not supported on AMDGPU")
check_type_supported(dtype_x) # early return if dtype_x is not supported
SIZE = 128
# define the kernel / launch-grid
@@ -256,6 +255,10 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
for dtype_y in dtypes_with_bfloat16
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
if torch.version.hip is not None:
if dtype_x == "bfloat16" and dtype_y == "bfloat16" :
pytest.skip("binary op with bfloat is not supported on AMDGPU")
expr = f' x {op} y'
if op == '%' and (dtype_x in dtypes and dtype_y in dtypes):
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.

View File

@@ -26,12 +26,14 @@ rm -rf /tmp/triton
# python python/test/test_empty.py
# -ex 'ignore 1 472' \
pytest --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
pytest -rs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
# pytest --verbose python/test/unit/language/test_core.py::test_empty_kernel[float32] 2>&1 | tee /dockerx/triton/test_empty_kernel.log
# pytest --verbose python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+] 2>&1 | tee /dockerx/triton/test_bin_op.log
# pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
# pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw[add-float16-all_neg] 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
# pytest --verbose "python/test/unit/language/test_core.py::test_reduce1d" 2>&1 | tee /dockerx/triton/test_reduce1d.log
# pytest --verbose "python/test/unit/language/test_core.py::test_cast[float32-float32-False]" 2>&1 | tee /dockerx/triton/test_cast.log
# pytest --verbose "python/test/unit/language/test_core.py::test_load_cache_modifier" 2>&1 | tee /dockerx/triton/test_vectorization.log
# mismatch
# pytest --verbose "python/test/unit/language/test_core.py::test_bin_op[int8-float16-%]" 2>&1 | tee /dockerx/triton/test_bin_op.log