uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types. Also:

- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
Committed by GitHub · parent d8db0308cb · commit 0ab9d67bad
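Since torch has no 16/32/64-bit unsigned dtypes, the test helpers changed in the diff below keep the bits of an unsigned array in a signed torch tensor and tell Triton to reinterpret them as the unsigned type (see `to_triton`/`to_numpy`). A minimal numpy-only sketch of that bit-reinterpretation idea, not part of the patch:

    import numpy as np

    x = np.array([0, 1, 40000, 65535], dtype=np.uint16)
    x_signed = x.astype(np.int16)                  # same bit patterns, stored as a signed type
    assert (x_signed.view(np.uint16) == x).all()   # reinterpreting the bits recovers the values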
@@ -1,7 +1,7 @@
 import copy
 import itertools
 import re
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState

 import triton
 import triton.language as tl
+from triton.code_gen import TensorWrapper, reinterpret

 int_dtypes = ['int8', 'int16', 'int32', 'int64']
+uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
-dtypes = int_dtypes + float_dtypes
+dtypes = int_dtypes + uint_dtypes + float_dtypes


+def _bitwidth(dtype: str) -> int:
+    # ex.: "int64" -> 64
+    return int(re.search(r'(\d+)$', dtype).group(1))


-def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
+def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
     """
     Override `rs` if you're calling this function twice and don't want the same
     result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
     if rs is None:
         rs = RandomState(seed=17)
     dtype = getattr(np, dtype_str)
-    if dtype_str in int_dtypes:
+    if dtype_str in int_dtypes + uint_dtypes:
         iinfo = np.iinfo(getattr(np, dtype_str))
-        x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
+        low = iinfo.min if low is None else max(low, iinfo.min)
+        high = iinfo.max if high is None else min(high, iinfo.max)
+        x = rs.randint(low, high, shape, dtype=dtype)
         x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
         return x
     elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
     raise RuntimeError(f'Unknown dtype {dtype_str}')


-def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
-    # For now, this always converts to a torch tensor, but when we add unsigned
-    # integers, it will also support TensorWrapper, since torch doesn't have
-    # unsigned support.
-    return torch.tensor(x, device=device)
+def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
+    t = x.dtype.name
+    if t in uint_dtypes:
+        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
+        x_signed = x.astype(getattr(np, signed_type_name))
+        return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
+    else:
+        return torch.tensor(x, device=device)


+def torch_dtype_name(dtype) -> str:
+    if isinstance(dtype, triton.language.dtype):
+        return dtype.name
+    elif isinstance(dtype, torch.dtype):
+        # 'torch.int64' -> 'int64'
+        m = re.match(r'^torch\.(\w+)$', str(dtype))
+        return m.group(1)
+    else:
+        raise TypeError(f'not a triton or torch dtype: {type(dtype)}')


 def to_numpy(x):
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, TensorWrapper):
+        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
+    elif isinstance(x, torch.Tensor):
         return x.cpu().numpy()
+    else:
+        raise ValueError(f"Not a triton-compatible tensor: {x}")
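A usage sketch of the helpers above, not part of the patch and assuming a CUDA device: an unsigned array should survive the round trip through `to_triton` and `to_numpy`.

    x = numpy_random((128,), dtype_str='uint64')
    x_tri = to_triton(x)                   # TensorWrapper around an int64 tensor, tagged tl.uint64
    assert isinstance(x_tri, TensorWrapper)
    assert (to_numpy(x_tri) == x).all()    # to_numpy casts back via torch_dtype_name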
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
     Given two dtype strings, returns the numpy dtype Triton thinks binary
     operations on the two types should return. Returns None if the return value
     matches numpy. This is generally needed because Triton and pytorch return
-    narrower floating point types than numpy in mixed operations.
+    narrower floating point types than numpy in mixed operations, and because
+    Triton follows C/C++ semantics around mixed signed/unsigned operations, and
+    numpy/pytorch do not.
     """
     overrides = {
         ('float16', 'int16'): np.float16,
         ('float16', 'int32'): np.float16,
         ('float16', 'int64'): np.float16,
+        ('float16', 'uint16'): np.float16,
+        ('float16', 'uint32'): np.float16,
+        ('float16', 'uint64'): np.float16,
+        ('int8', 'uint8'): np.uint8,
+        ('int8', 'uint16'): np.uint16,
+        ('int8', 'uint32'): np.uint32,
+        ('int8', 'uint64'): np.uint64,
+        ('int16', 'uint16'): np.uint16,
+        ('int16', 'uint32'): np.uint32,
+        ('int16', 'uint64'): np.uint64,
+        ('int32', 'uint32'): np.uint32,
+        ('int32', 'uint64'): np.uint64,
+        ('int64', 'uint64'): np.uint64,
     }
     key = (a, b) if a < b else (b, a)
     return overrides.get(key)


-def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
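The updated docstring above notes that Triton follows C/C++ rules for mixed signed/unsigned integer arithmetic while numpy and pytorch do not. A quick illustration of the divergence the overrides table compensates for, not part of the patch:

    import numpy as np

    # numpy widens a mixed int32/uint32 operation to int64 so every value of
    # both operands stays representable...
    assert np.result_type(np.int32, np.uint32) == np.int64
    # ...whereas C/C++ (and, per the overrides table above, Triton) convert the
    # signed operand and compute in uint32.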
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
     # inputs
     rs = RandomState(17)
     x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
-    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
     if mode_x == 'nan':
         x[:] = float('nan')
     if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
         ('int64', 'float16'),
         ('int64', 'float32'),
         ('int64', 'float64'),
+        ('uint16', 'float16'),
+        ('uint16', 'float32'),
+        ('uint32', 'float16'),
+        ('uint32', 'float32'),
+        ('uint64', 'float16'),
+        ('uint64', 'float32'),
+        ('uint64', 'float64'),
     ]

 # ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
-    if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
+    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
         numpy_expr = 'np.fmod(x, y)'
     elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
         numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
+    elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
     else:
         numpy_expr = None
     if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
         with pytest.raises(AssertionError, match='Not equal to tolerance'):
             _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
+    elif (op in ('%', '/') and
+          ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
+           (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
+        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
+        assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
     else:
         _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


+@pytest.mark.parametrize("dtype_x, dtype_y",
+                         [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
+                         [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
+                         )
+def test_floordiv(dtype_x, dtype_y, device='cuda'):
+    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
+    # through to //, so we have to use a nonstandard expression to get a
+    # reference result for //.
+    expr = 'x // y'
+    numpy_expr = '((x - np.fmod(x, y)) / y)'
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


 # ---------------
 # test bitwise ops
 # ---------------
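The comment in the new `test_floordiv` above explains why the reference expression is `((x - np.fmod(x, y)) / y)` rather than numpy's own `//`: `np.fmod` keeps the sign of the dividend (as C's `%` does), so the derived quotient truncates toward zero instead of flooring. A small illustration, not part of the patch:

    import numpy as np

    x, y = np.int32(-7), np.int32(2)
    print(x // y)                      # -4: numpy floors the quotient
    print((x - np.fmod(x, y)) / y)     # -3.0: fmod(-7, 2) == -1, giving C-style truncation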
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
+    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+    else:
+        numpy_expr = None
     if 'float' in dtype_x + dtype_y:
         with pytest.raises(triton.code_gen.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
         # The CompilationError must have been caused by a C++ exception with this text.
         assert re.match('invalid operands of type', str(exc_info.value.__cause__))
     else:
-        _test_binary(dtype_x, dtype_y, expr, device=device)
+        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['<<', '>>']
+    for dtype_x in int_dtypes + uint_dtypes
+    for dtype_y in int_dtypes + uint_dtypes
+])
+def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f'x {op} y'
+    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
+    dtype_z = f'uint{bw}'
+    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)


 # ---------------
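The new `test_shift_op` casts both operands to an unsigned type of the wider bit width for the numpy reference and keeps the shift amount small and nonnegative (`y_low=0, y_high=65`). One reason the reference cast matters, shown with plain numpy and not part of the patch: right-shifting a signed value is an arithmetic shift, while the same bit pattern as an unsigned value shifts logically.

    import numpy as np

    assert np.int8(-128) >> np.int8(1) == -64   # arithmetic shift preserves the sign bit
    assert np.uint8(128) >> np.uint8(1) == 64   # logical shift on the same bit pattern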
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
 ])
 def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     expr = f'x {op} y'
-    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
+    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+    else:
+        numpy_expr = None
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


 # ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
     (dtype_x, ' -x') for dtype_x in dtypes
-] + [\
+] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
 def test_unary_op(dtype_x, expr, device='cuda'):
     _test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):


 @pytest.mark.parametrize("expr, dtype_str", [
-    (f'x[{s}]', 'int32')
+    (f'x[{s}]', d)
     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
+    for d in ['int32', 'uint32', 'uint16']
 ])
 def test_index1d(expr, dtype_str, device='cuda'):
     rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
 @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
     [
         ('add', 'float16', mode),
-        ('add', 'int32', mode), ('add', 'float32', mode),
-        ('max', 'int32', mode), ('max', 'float32', mode),
-        ('min', 'int32', mode), ('min', 'float32', mode),
+        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
     ]
     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
 def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     if exact:
         assert z_ref.item() == to_numpy(z_tri).item()
     else:
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


 # ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
-]
-)
+] + [
+    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
+] + [
+    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
+])
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
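The new `test_cast` cases exercise bitcasts between same-width signed and unsigned types, i.e. reinterpreting the bits rather than converting the value. In numpy terms, an illustration that is not part of the patch:

    import numpy as np

    x = np.array([-1, -128], dtype=np.int8)
    print(x.view(np.uint8))   # [255 128]: same bytes, read as unsigned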
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):


 @pytest.mark.parametrize("dtype_str, shape, axis", [
-    ('float32', (1, 1024), 1)
+    (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
 ])
 def test_reduce2d(dtype_str, shape, axis, device='cuda'):
     # triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
         pass
     x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
     kernel[(1, )](x)


+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
+    (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

+    @triton.jit
+    def kernel(VALUE, X):
+        pass

+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)

+    # Parse out the type of the 'VALUE' parameter from the Triton IR.
+    triton_ir = pgm.asm['ttir']
+    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
+    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
+    assert ir_value_type == value_type


+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:

+    @triton.jit
+    def kernel(VALUE, X):
+        pass

+    x = torch.tensor([3.14159], device='cuda')

+    if overflow:
+        with pytest.raises(RuntimeError, match='integer overflow'):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
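The expectations above encode which integer type a scalar kernel argument is specialized to in the IR. A hypothetical helper mirroring the parametrized cases, not part of the patch (the special handling of `1`, which presumably gets specialized to a compile-time constant and so has no `VALUE` parameter in the IR, is inferred from the table):

    def expected_value_type(v: int):
        # Mirrors the (value, value_type) table in test_value_specialization.
        if v == 1:
            return None            # no VALUE parameter appears in the IR for 1
        if -2**31 <= v <= 2**31 - 1:
            return 'i32'
        if 2**31 <= v <= 2**32 - 1:
            return 'u32'
        if -2**63 <= v <= 2**63 - 1:
            return 'i64'
        if 2**63 <= v <= 2**64 - 1:
            return 'u64'
        raise RuntimeError('integer overflow')   # cf. test_value_specialization_overflow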
@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK),)
     kernel[grid](x, N, seed)
+    assert all((x >= 0) & (x <= 1))
     assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

 # test normal PRNG