[BACKEND] Better bf16 support (#588)
@@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
         shape = (shape, )
     if rs is None:
         rs = RandomState(seed=17)
-    dtype = getattr(np, dtype_str)
     if dtype_str in int_dtypes + uint_dtypes:
         iinfo = np.iinfo(getattr(np, dtype_str))
         low = iinfo.min if low is None else max(low, iinfo.min)
         high = iinfo.max if high is None else min(high, iinfo.max)
+        dtype = getattr(np, dtype_str)
         x = rs.randint(low, high, shape, dtype=dtype)
         x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
         return x
     elif dtype_str in float_dtypes:
-        return rs.normal(0, 1, shape).astype(dtype)
+        return rs.normal(0, 1, shape).astype(dtype_str)
+    elif dtype_str == 'bfloat16':
+        return (rs.normal(0, 1, shape).astype('float32').view('uint32')
+                & np.uint32(0xffff0000)).view('float32')
     else:
         raise RuntimeError(f'Unknown dtype {dtype_str}')


-def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
+def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
+    '''
+    Note: We need dst_type because the type of x can be different from dst_type.
+          For example: x is of type `float32`, dst_type is `bfloat16`.
+          If dst_type is None, we infer dst_type from x.
+    '''
     t = x.dtype.name
     if t in uint_dtypes:
         signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
         x_signed = x.astype(getattr(np, signed_type_name))
         return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
     else:
+        if t == 'float32' and dst_type == 'bfloat16':
+            return torch.tensor(x, device=device).bfloat16()
         return torch.tensor(x, device=device)

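For readers unfamiliar with the masking trick in `numpy_random` above: numpy has no native bfloat16 dtype, so the helper keeps the array in float32 but zeroes the low 16 bits of every value, leaving exactly the sign, exponent, and 7 mantissa bits that bfloat16 can hold. A minimal standalone sketch (not part of the diff) that checks the round trip through `torch.bfloat16`:

```python
# Standalone sketch (not part of this commit): zeroing the low 16 bits of a
# float32 keeps only what bfloat16 can represent, so converting to bf16 and
# back is lossless for these values.
import numpy as np
import torch

x = np.random.RandomState(17).normal(0, 1, 8).astype('float32')
x_bf16ish = (x.view('uint32') & np.uint32(0xffff0000)).view('float32')

assert torch.equal(torch.tensor(x_bf16ish).bfloat16().float(),
                   torch.tensor(x_bf16ish))
```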
@@ -72,6 +82,8 @@ def to_numpy(x):
     if isinstance(x, TensorWrapper):
         return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
     elif isinstance(x, torch.Tensor):
+        if x.dtype is torch.bfloat16:
+            return x.cpu().float().numpy()
         return x.cpu().numpy()
     else:
         raise ValueError(f"Not a triton-compatible tensor: {x}")
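The new branch in `to_numpy` exists because numpy cannot hold bfloat16 data, so a bf16 tensor is upcast before `.numpy()`. The upcast is exact, since every bfloat16 value is also a float32 value. A small illustrative sketch (assumption: a recent PyTorch build with CPU bf16 support):

```python
# Sketch only: bf16 tensors must go through float32 on their way to numpy.
import numpy as np
import torch

t = torch.randn(4, dtype=torch.bfloat16)
# calling t.numpy() directly would fail, because numpy has no bfloat16 dtype
a = t.float().numpy()      # exact upcast, no rounding
assert a.dtype == np.float32
```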
@@ -84,19 +96,30 @@ def patch_kernel(template, to_replace):
     return kernel


-@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
+def check_type_supported(dtype):
+    '''
+    skip test if dtype is not supported on the current device
+    '''
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
+        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
+
+
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128

     @triton.jit
     def kernel(X, SIZE: tl.constexpr):
         pass
-    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
+    check_type_supported(dtype_x)
+    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)


 # generic test functions
 def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
+    check_type_supported(dtype_x)  # early return if dtype_x is not supported
     SIZE = 128
     # define the kernel / launch-grid

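The new `check_type_supported` gate accepts the dtype in any of the three forms the tests use (a `tl` dtype, a plain string, or a `torch` dtype) and skips on devices below compute capability 8.0, where the bfloat16 paths exercised here are not supported. An equivalent gate written against the public torch API instead of `_triton.runtime.cc` (an assumption for illustration, not the code above) could look like:

```python
# Illustrative alternative using torch's public API instead of _triton.runtime.cc.
import pytest
import torch
import triton.language as tl

def check_type_supported_alt(dtype):
    major, minor = torch.cuda.get_device_capability()
    cc = major * 10 + minor   # e.g. 80 on A100, 75 on T4
    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
```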
@@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     # reference result
     z_ref = eval(expr if numpy_expr is None else numpy_expr)
     # triton result
-    x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x)
+    z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
     kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
     # compare
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
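Passing `dst_type=dtype_x` means that for bfloat16 the kernel really operates on bf16 tensors while the reference `z_ref` stays in float32; the comparison still works because `to_numpy` upcasts the result and `rtol=0.01` is far looser than bf16's worst-case relative rounding error of roughly 2^-8 (about 0.4%). A quick sanity check of that bound (sketch, not part of the diff):

```python
# Sketch: rounding float32 data to bf16 stays well inside rtol=0.01.
import numpy as np
import torch

x = np.random.RandomState(17).normal(0, 1, 128).astype('float32')
x_bf16 = torch.tensor(x).bfloat16()
np.testing.assert_allclose(x, x_bf16.float().numpy(), rtol=0.01)
```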
@@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:


 def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
+    check_type_supported(dtype_x)  # early return if dtype_x is not supported
+    check_type_supported(dtype_y)
     SIZE = 128
     # define the kernel / launch-grid

@@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
     if dtype_z is not None:
         z_ref = z_ref.astype(dtype_z)
     # triton result
-    x_tri = to_triton(x, device=device)
-    y_tri = to_triton(y, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x)
+    y_tri = to_triton(y, device=device, dst_type=dtype_y)
     z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
     kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
@@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
     # remainders than stock LLVM. We currently don't expect to match it
     # bit-for-bit.
     return (dtype_x, dtype_y) in [
+        ('int32', 'bfloat16'),
         ('int32', 'float16'),
         ('int32', 'float32'),
+        ('int64', 'bfloat16'),
         ('int64', 'float16'),
         ('int64', 'float32'),
         ('int64', 'float64'),
+        ('uint16', 'bfloat16'),
         ('uint16', 'float16'),
         ('uint16', 'float32'),
+        ('uint32', 'bfloat16'),
         ('uint32', 'float16'),
         ('uint32', 'float32'),
+        ('uint64', 'bfloat16'),
         ('uint64', 'float16'),
         ('uint64', 'float32'),
         ('uint64', 'float64'),
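The five new bfloat16 pairs follow the same reasoning as the float16 ones: the integer operand is rounded to the low-precision float before the remainder is taken, and with only 8 significand bits that rounding can move the operand by hundreds of units, which changes the remainder completely. A small numeric illustration (values chosen arbitrarily, sketch only):

```python
# Sketch: why int % bfloat16 is ill-conditioned.
import numpy as np
import torch

x, y = 1000001, 3.14159
exact = np.fmod(np.float64(x), np.float64(y))
x_bf16 = torch.tensor(float(x)).bfloat16().float().item()   # 1000001 -> 999424.0
approx = np.fmod(x_bf16, y)
print(exact, approx)   # the two remainders have nothing in common
```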
@@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['+', '-', '*', '/', '%']
-    for dtype_x in dtypes
-    for dtype_y in dtypes
+    for dtype_x in dtypes + ['bfloat16']
+    for dtype_y in dtypes + ['bfloat16']
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
     if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
         numpy_expr = 'np.fmod(x, y)'
-    elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
+    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
         # Triton promotes 16-bit floating-point / and % to 32-bit because there
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
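The `np.fmod` substitution above only matters when operands can be negative: C-style (and LLVM) integer remainder takes the sign of the dividend, while numpy's `%` (`np.remainder`) takes the sign of the divisor. A short standalone illustration (not part of the diff):

```python
# Sketch: fmod vs remainder on negative operands.
import numpy as np

a, b = np.int32(-7), np.int32(3)
print(np.fmod(a, b))       # -1, C/LLVM-style remainder
print(np.remainder(a, b))  #  2, Python-style floor mod
```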
@@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
     (dtype_x, dtype_y, op)
     for op in ['&', '|', '^']
-    for dtype_x in dtypes
-    for dtype_y in dtypes
+    for dtype_x in dtypes + ['bfloat16']
+    for dtype_y in dtypes + ['bfloat16']
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
@@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, ' -x') for dtype_x in dtypes
+    (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
 ] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
@@ -732,9 +762,10 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-                          for dtype in dtypes
+                          for dtype in dtypes + ['bfloat16']
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

     # triton kernel
     @triton.jit
@@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
                 'argmin': np.argmin, 'argmax': np.argmax}[op]
     # numpy result
     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-    z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+    z_tri_dtype_str = z_dtype_str
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        z_tri_dtype_str = 'bfloat16'
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
     z_tri = to_numpy(z_tri)
     # compare
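Because `z_tri` is allocated as bfloat16 when the input is bfloat16, the float32 numpy reference is clipped to the same precision before the comparison; otherwise the reference would carry 16 extra mantissa bits the Triton result could never match. What the masking does to a single value (standalone sketch, not part of the diff):

```python
# Sketch: dropping the low 16 bits leaves the nearest bf16-representable
# value at or below the original float32.
import numpy as np

z_ref = np.array([3.14159265], dtype=np.float32)
z_trunc = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
print(z_trunc)   # [3.140625]
```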
@@ -770,7 +810,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):


 reduce_configs1 = [
-    (op, dtype, (1, 1024), axis) for dtype in dtypes
+    (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
@@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                 'argmin': np.argmin, 'argmax': np.argmax}[op]
     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
+    z_tri_dtype_str = z_dtype_str
     # numpy result
-    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_tri_dtype_str = 'bfloat16'
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
     z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
-                      device=device)
+                      device=device, dst_type=z_tri_dtype_str)
     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
     z_tri = to_numpy(z_tri)
     # compare
@@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):

 @pytest.mark.parametrize("dtype_str, shape, perm",
                          [(dtype, shape, perm)
-                          for dtype in ['float16', 'float32']
+                          for dtype in ['bfloat16', 'float16', 'float32']
                           for shape in [(64, 64), (128, 128)]
                           for perm in [(1, 0)]])
 def test_permute(dtype_str, shape, perm, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

     # triton kernel
     @triton.jit
@@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
     # input
     x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = to_triton(np.empty_like(x), device=device)
-    z_tri_contiguous = to_triton(np.empty_like(x), device=device)
-    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    x_tri = to_triton(x, device=device, dst_type=dtype_str)
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          z_tri, z_tri.stride(1), z_tri.stride(0),
                          BLOCK_M=shape[0], BLOCK_N=shape[1])
     pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
                                     z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
                                     BLOCK_M=shape[0], BLOCK_N=shape[1])
-    # torch result
+    # numpy result
     z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
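The permute test never materializes a transposed tensor; it relies on a transpose being the same storage indexed with swapped strides, which is what the `z_tri.stride(1), z_tri.stride(0)` arguments express. A small reminder of that invariant (sketch, CPU is enough here):

```python
# Sketch: transposition swaps strides, the storage is untouched.
import torch

x = torch.empty(64, 64, dtype=torch.bfloat16)
assert x.stride() == (64, 1)
assert x.t().stride() == (1, 64)
assert x.t().data_ptr() == x.data_ptr()
```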
@@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'):
 # Testing masked loads with an intermediate copy to shared memory run.


-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+
     M = 32
     N = 32
     K = 16
@@ -2,18 +2,22 @@ import pytest
 import torch

 import triton
+import triton._C.libtriton.triton as _triton


 @pytest.mark.parametrize("M, N, dtype, mode",
                          [
                              (M, N, dtype, mode) for M in [1024, 821]
                              for N in [512, 857, 1871, 2089, 8573, 31000]
-                             for dtype in ['float16', 'float32']
+                             for dtype in ['bfloat16', 'float16', 'float32']
                              for mode in ['forward', 'backward']
                          ]
                          )
 def test_op(M, N, dtype, mode):
-    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and dtype == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
     # create inputs
     x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
     idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
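Note the ordering in the new `test_op`: the capability check compares the still-unmapped string `'bfloat16'`, so it must run before `dtype` is rebound to a torch dtype. The same skip-then-map pattern, written against torch's public capability query for illustration (an assumption; the test itself goes through `_triton.runtime.cc`):

```python
# Illustrative sketch of the skip-then-map pattern used in test_op.
import pytest
import torch

def resolve_dtype(name: str) -> torch.dtype:
    major, minor = torch.cuda.get_device_capability()
    if major * 10 + minor < 80 and name == 'bfloat16':
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    return {'bfloat16': torch.bfloat16,
            'float16': torch.float16,
            'float32': torch.float32}[name]
```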