uint8, uint16, uint32, and uint64 in kernels (#413)

A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
Madeleine Thompson
2022-01-05 15:27:17 -08:00
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
12 changed files with 444 additions and 110 deletions

View File

@@ -1,7 +1,7 @@
import copy
import itertools
import re
from typing import Optional
from typing import Optional, Union
import numpy as np
import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState
import triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
dtypes = int_dtypes + uint_dtypes + float_dtypes
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes:
if dtype_str in int_dtypes + uint_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
low = iinfo.min if low is None else max(low, iinfo.min)
high = iinfo.max if high is None else min(high, iinfo.max)
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
raise RuntimeError(f'Unknown dtype {dtype_str}')
def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
# For now, this always converts to a torch tensor, but when we add unsigned
# integers, it will also support TensorWrapper, since torch doesn't have
# unsigned support.
return torch.tensor(x, device=device)
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
t = x.dtype.name
if t in uint_dtypes:
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
return torch.tensor(x, device=device)
def torch_dtype_name(dtype) -> str:
if isinstance(dtype, triton.language.dtype):
return dtype.name
elif isinstance(dtype, torch.dtype):
# 'torch.int64' -> 'int64'
m = re.match(r'^torch\.(\w+)$', str(dtype))
return m.group(1)
else:
raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
def to_numpy(x):
if isinstance(x, torch.Tensor):
if isinstance(x, TensorWrapper):
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
elif isinstance(x, torch.Tensor):
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations.
narrower floating point types than numpy in mixed operations, and because
Triton follows C/C++ semantics around mixed signed/unsigned operations, and
numpy/pytorch do not.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
('float16', 'uint16'): np.float16,
('float16', 'uint32'): np.float16,
('float16', 'uint64'): np.float16,
('int8', 'uint8'): np.uint8,
('int8', 'uint16'): np.uint16,
('int8', 'uint32'): np.uint32,
('int8', 'uint64'): np.uint64,
('int16', 'uint16'): np.uint16,
('int16', 'uint32'): np.uint32,
('int16', 'uint64'): np.uint64,
('int32', 'uint32'): np.uint32,
('int32', 'uint64'): np.uint64,
('int64', 'uint64'): np.uint64,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'float16'),
('uint32', 'float32'),
('uint64', 'float16'),
('uint64', 'float32'),
('uint64', 'float64'),
]
# ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Not equal to tolerance'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
# reference result for //.
expr = 'x // y'
numpy_expr = '((x - np.fmod(x, y)) / y)'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['<<', '>>']
for dtype_x in int_dtypes + uint_dtypes
for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
dtype_z = f'uint{bw}'
numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
# ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
# ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes
] + [\
] + [
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', 'int32')
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'int32', mode), ('add', 'float32', mode),
('max', 'int32', mode), ('max', 'float32', mode),
('min', 'int32', mode), ('min', 'float32', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
]
)
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, axis", [
('float32', (1, 1024), 1)
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
pass
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
kernel[(1, )](x)
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x)
# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
assert ir_value_type == value_type
@pytest.mark.parametrize(
"value, overflow",
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
)
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)

View File

@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG