print irs
This commit is contained in:
@@ -113,14 +113,14 @@ def check_type_supported(dtype):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SIZE: tl.constexpr):
|
||||
pass
|
||||
# check_type_supported(dtype_x)
|
||||
check_type_supported(dtype_x)
|
||||
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
@@ -885,9 +885,9 @@ def test_f16_to_f8_rounding():
|
||||
for dtype in dtypes_with_bfloat16
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
if torch.version.hip is not None:
|
||||
pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -946,6 +946,7 @@ reduce_configs2 = [
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
|
||||
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
if torch.version.hip is not None:
|
||||
pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
|
||||
# triton kernel
|
||||
|
Reference in New Issue
Block a user