[BACKEND] Better bf16 support (#588)

This commit is contained in:
daadaada
2022-07-20 12:22:37 +08:00
committed by GitHub
parent 86cab58d89
commit 9b2bc88d11
6 changed files with 180 additions and 62 deletions

View File

@@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
shape = (shape, )
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes + uint_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
low = iinfo.min if low is None else max(low, iinfo.min)
high = iinfo.max if high is None else min(high, iinfo.max)
dtype = getattr(np, dtype_str)
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
return rs.normal(0, 1, shape).astype(dtype)
return rs.normal(0, 1, shape).astype(dtype_str)
elif dtype_str == 'bfloat16':
return (rs.normal(0, 1, shape).astype('float32').view('uint32')
& np.uint32(0xffff0000)).view('float32')
else:
raise RuntimeError(f'Unknown dtype {dtype_str}')
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
'''
Note: We need dst_type becasue the type of x can be different from dst_type.
For example: x is of type `float32`, dst_type is `bfloat16`.
If dst_type is None, we infer dst_type from x.
'''
t = x.dtype.name
if t in uint_dtypes:
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
if t == 'float32' and dst_type == 'bfloat16':
return torch.tensor(x, device=device).bfloat16()
return torch.tensor(x, device=device)
@@ -72,6 +82,8 @@ def to_numpy(x):
if isinstance(x, TensorWrapper):
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
elif isinstance(x, torch.Tensor):
if x.dtype is torch.bfloat16:
return x.cpu().float().numpy()
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -84,19 +96,30 @@ def patch_kernel(template, to_replace):
return kernel
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def check_type_supported(dtype):
'''
skip test if dtype is not supported on the current device
'''
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
check_type_supported(dtype_x)
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
check_type_supported(dtype_x) # early return if dtype_x is not supported
SIZE = 128
# define the kernel / launch-grid
@@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
# reference result
z_ref = eval(expr if numpy_expr is None else numpy_expr)
# triton result
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.empty_like(z_ref), device=device)
x_tri = to_triton(x, device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
check_type_supported(dtype_x) # early return if dtype_x is not supported
check_type_supported(dtype_y)
SIZE = 128
# define the kernel / launch-grid
@@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
if dtype_z is not None:
z_ref = z_ref.astype(dtype_z)
# triton result
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
x_tri = to_triton(x, device=device, dst_type=dtype_x)
y_tri = to_triton(y, device=device, dst_type=dtype_y)
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
@@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# remainders than stock LLVM. We currently don't expect to match it
# bit-for-bit.
return (dtype_x, dtype_y) in [
('int32', 'bfloat16'),
('int32', 'float16'),
('int32', 'float32'),
('int64', 'bfloat16'),
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'bfloat16'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'bfloat16'),
('uint32', 'float16'),
('uint32', 'float32'),
('uint64', 'bfloat16'),
('uint64', 'float16'),
('uint64', 'float32'),
('uint64', 'float64'),
@@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
for dtype_x in dtypes + ['bfloat16']
for dtype_y in dtypes + ['bfloat16']
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
# Triton promotes 16-bit floating-point / and % to 32-bit because there
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
@@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['&', '|', '^']
for dtype_x in dtypes
for dtype_y in dtypes
for dtype_x in dtypes + ['bfloat16']
for dtype_y in dtypes + ['bfloat16']
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
@@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes
(dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16']
] + [
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
@@ -732,9 +762,10 @@ def test_f16_to_f8_rounding():
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for dtype in dtypes
for dtype in dtypes + ['bfloat16']
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
@@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
z_tri_dtype_str = z_dtype_str
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# trunc mantissa for a fair comparison of accuracy
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
z_tri_dtype_str = 'bfloat16'
else:
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
z_tri = to_numpy(z_tri)
# compare
@@ -770,7 +810,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
reduce_configs1 = [
(op, dtype, (1, 1024), axis) for dtype in dtypes
(op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16']
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for axis in [1]
]
@@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_tri_dtype_str = z_dtype_str
# numpy result
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
z_tri_dtype_str = 'bfloat16'
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# trunc mantissa for a fair comparison of accuracy
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
else:
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
device=device)
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)
# compare
@@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm)
for dtype in ['float16', 'float32']
for dtype in ['bfloat16', 'float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
@@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# input
x = numpy_random(shape, dtype_str=dtype_str)
# triton result
z_tri = to_triton(np.empty_like(x), device=device)
z_tri_contiguous = to_triton(np.empty_like(x), device=device)
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
x_tri = to_triton(x, device=device, dst_type=dtype_str)
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1])
pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# torch result
# numpy result
z_ref = x.transpose(*perm)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
@@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'):
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
M = 32
N = 32
K = 16

View File

@@ -2,18 +2,22 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for dtype in ['bfloat16', 'float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')