uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:

- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
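
Since torch has no unsigned integer tensor types wider than uint8, the new tests stage unsigned data through a signed torch tensor and reinterpret it for Triton. A minimal sketch of that pattern, mirroring the `to_triton` helper added below (variable names are illustrative and a CUDA device is assumed):

```python
import numpy as np
import torch

import triton.language as tl
from triton.code_gen import reinterpret  # returns a TensorWrapper for uint dtypes

# Store the uint32 bits in an int32 torch tensor, then tell Triton to treat
# that memory as tl.uint32 when the tensor is passed to a kernel.
x_np = np.arange(16, dtype=np.uint32)
x_torch = torch.tensor(x_np.astype(np.int32), device='cuda')
x_tri = reinterpret(x_torch, tl.uint32)
```

Inside a `@triton.jit` kernel the wrapped argument is then typed as `tl.uint32`, so arithmetic and comparisons are intended to follow unsigned semantics.
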
@@ -109,6 +109,24 @@ std::string pow2_divisor(long N){
return "1";
}

// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object& dtype) {
if (py::hasattr(dtype, "cache_key_part")) {
// Presumed to be a triton.language.dtype.
return std::string(py::str(py::getattr(dtype, "cache_key_part")));
} else {
// Remove 'torch.' prefix from repr of torch.dtype.
py::object repr = py::repr(dtype);
size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
}
return std::string(repr_ptr + 6, repr_len - 6);
}
}

// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
cache_key += "1";
continue;
}
// long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){
// int32, uint32, int64, and uint64 have different kernels
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
cache_key += "int32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
}
else{
} else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
cache_key += "uint32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
cache_key += "int64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
if(overflow){
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
std::memcpy(&value, &uvalue, 8);
}
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
}
cache_key += "uint64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
if(!specialize)
continue;
@@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
py::object dtype = arg.attr("dtype");
py::object repr = py::repr(dtype);
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len);
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
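
In Python terms, the integer-specialization branches added to `parse_args` above pick the narrowest of int32/uint32/int64/uint64 that can represent a scalar argument, and report an overflow for anything outside the uint64 range. A rough sketch of that rule (hypothetical helper, not code from the patch; the `cache_key += "1"` special case and the parameter packing are omitted):

```python
def int_cache_key_part(value: int) -> str:
    # Mirrors the new parse_args branches: int32, then uint32, then int64,
    # then uint64; anything else is an overflow error.
    if -2**31 <= value < 2**31:
        return 'int32'
    if 2**31 <= value < 2**32:
        return 'uint32'
    if -2**63 <= value < 2**63:
        return 'int64'
    if 2**63 <= value < 2**64:
        return 'uint64'
    raise OverflowError(f'integer overflow in argument: {value}')

assert int_cache_key_part(-1) == 'int32'
assert int_cache_key_part(2**31) == 'uint32'
assert int_cache_key_part(-2**40) == 'int64'
assert int_cache_key_part(2**63) == 'uint64'
```

The same mapping shows up again in the `class Kernel` and `_to_ir` changes further down.
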
@@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) {
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)

.def("is_void", &ir::type::is_void_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
@@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) {
.def("is_bf16", &ir::type::is_bf16_ty)
.def("is_fp32", &ir::type::is_fp32_ty)
.def("is_fp64", &ir::type::is_fp64_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })

.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
@@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) {
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_int64", &ir::builder::get_int64, ret::reference)
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference);

@@ -1,7 +1,7 @@
import copy
import itertools
import re
from typing import Optional
from typing import Optional, Union

import numpy as np
import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState

import triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
dtypes = int_dtypes + uint_dtypes + float_dtypes


def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))


def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes:
if dtype_str in int_dtypes + uint_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
low = iinfo.min if low is None else max(low, iinfo.min)
high = iinfo.max if high is None else min(high, iinfo.max)
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
raise RuntimeError(f'Unknown dtype {dtype_str}')


def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
# For now, this always converts to a torch tensor, but when we add unsigned
# integers, it will also support TensorWrapper, since torch doesn't have
# unsigned support.
return torch.tensor(x, device=device)
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
t = x.dtype.name
if t in uint_dtypes:
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
return torch.tensor(x, device=device)


def torch_dtype_name(dtype) -> str:
if isinstance(dtype, triton.language.dtype):
return dtype.name
elif isinstance(dtype, torch.dtype):
# 'torch.int64' -> 'int64'
m = re.match(r'^torch\.(\w+)$', str(dtype))
return m.group(1)
else:
raise TypeError(f'not a triton or torch dtype: {type(dtype)}')


def to_numpy(x):
if isinstance(x, torch.Tensor):
if isinstance(x, TensorWrapper):
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
elif isinstance(x, torch.Tensor):
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations.
narrower floating point types than numpy in mixed operations, and because
Triton follows C/C++ semantics around mixed signed/unsigned operations, and
numpy/pytorch do not.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
('float16', 'uint16'): np.float16,
('float16', 'uint32'): np.float16,
('float16', 'uint64'): np.float16,
('int8', 'uint8'): np.uint8,
('int8', 'uint16'): np.uint16,
('int8', 'uint32'): np.uint32,
('int8', 'uint64'): np.uint64,
('int16', 'uint16'): np.uint16,
('int16', 'uint32'): np.uint32,
('int16', 'uint64'): np.uint64,
('int32', 'uint32'): np.uint32,
('int32', 'uint64'): np.uint64,
('int64', 'uint64'): np.uint64,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)


def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'float16'),
('uint32', 'float32'),
('uint64', 'float16'),
('uint64', 'float32'),
('uint64', 'float64'),
]

# ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Not equal to tolerance'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
# reference result for //.
expr = 'x // y'
numpy_expr = '((x - np.fmod(x, y)) / y)'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


# ---------------
# test bitwise ops
# ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['<<', '>>']
for dtype_x in int_dtypes + uint_dtypes
for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
dtype_z = f'uint{bw}'
numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)


# ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


# ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes
] + [\
] + [
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)

@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):


@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', 'int32')
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'int32', mode), ('add', 'float32', mode),
('max', 'int32', mode), ('max', 'float32', mode),
('min', 'int32', mode), ('min', 'float32', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


# ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
]
)
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):


@pytest.mark.parametrize("dtype_str, shape, axis", [
('float32', (1, 1024), 1)
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
pass
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
kernel[(1, )](x)


@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

@triton.jit
def kernel(VALUE, X):
pass

x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x)

# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
assert ir_value_type == value_type


@pytest.mark.parametrize(
"value, overflow",
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
)
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:

@triton.jit
def kernel(VALUE, X):
pass

x = torch.tensor([3.14159], device='cuda')

if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)

@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

# test normal PRNG

@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr):
op = op.value
# print(op)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
self.shared_mem = shared_mem
self.num_warps = num_warps


class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
return 'L'
if -2**31 <= obj < 2**31:
return 'i32'
elif 2**31 <= obj < 2**32:
return 'u32'
elif -2**63 <= obj < 2**63:
return 'i64'
elif 2**63 <= obj < 2**64:
return 'u64'
else:
raise ValueError(f'integer overflow representing {obj}')
if isinstance(obj, float):
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False

raise NotImplementedError(f'could not compute type name for {obj}')

@staticmethod
def _to_triton_ir(context, obj):
@@ -607,6 +616,10 @@ class Kernel:
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
'u8': _triton.ir.type.get_uint8,
'u16': _triton.ir.type.get_uint16,
'u32': _triton.ir.type.get_uint32,
'u64': _triton.ir.type.get_uint64,
}
# convert torch.Tensor to Triton IR pointers
if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:


def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
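
The `reinterpret` change above also keeps wrappers flat: re-wrapping goes through `tensor.base`, and reinterpreting back to the base's own dtype returns the underlying `torch.Tensor`. A small illustration of the intended behavior (hypothetical tensors, assuming a CUDA device):

```python
import torch

import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret

t = torch.empty(4, dtype=torch.int32, device='cuda')

w = reinterpret(t, tl.uint32)       # plain tensor in: wrap it in a TensorWrapper
w2 = reinterpret(w, tl.uint16)      # wrapped tensor in: re-wrap w.base, no nesting
back = reinterpret(w, torch.int32)  # dtype matches the base dtype: unwrap to t itself

assert isinstance(w, TensorWrapper) and isinstance(w2, TensorWrapper)
assert back is t
```
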
@@ -9,9 +9,16 @@ def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
if x.__abs__() <= 2**31:
if -2**31 <= x < 2**31:
return builder.get_int32(x)
return builder.get_int64(x)
elif 2**31 <= x < 2**32:
return builder.get_uint32(x)
elif -2**63 <= x < 2**63:
return builder.get_int64(x)
elif 2**63 <= x < 2**64:
return builder.get_uint64(x)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return builder.get_float32(x)
elif isinstance(x, constexpr):
@@ -83,6 +90,14 @@ class dtype:
def __str__(self):
return self.name

@property
def cache_key_part(self) -> str:
"""See cache_key_part() in triton.cc."""
return self.name

def __repr__(self):
return f'triton.language.{self.name}'


class pointer_dtype:
def __init__(self, element_ty):
@@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
uint8 = dtype(ir.type.get_uint8)
uint16 = dtype(ir.type.get_uint16)
uint32 = dtype(ir.type.get_uint32)
uint64 = dtype(ir.type.get_uint64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16)
@@ -120,6 +139,10 @@ class block:
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_uint8(): return uint8
if ir_type.is_uint16(): return uint16
if ir_type.is_uint32(): return uint32
if ir_type.is_uint64(): return uint64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_bf16(): return bfloat16