Support uint8, uint16, uint32, and uint64 in kernels (#413)

Add support for uint8, uint16, uint32, and uint64 in Triton kernels. A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler (see the sketch below).
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
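
As a rough usage sketch (not taken from this commit, and assuming a CUDA device), this is how the new unsigned types are expected to be driven from Python, based on the `reinterpret`/`TensorWrapper` changes in the `code_gen.py` hunk below: torch has no `uint32` tensors, so the bits live in a signed tensor and only the Triton-side element type is reinterpreted.

```python
import torch
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret

# torch has no uint32 dtype, so carry the bits in an int32 tensor and
# reinterpret the Triton-side element type.
x_signed = torch.tensor([1, 2, 3], dtype=torch.int32, device='cuda')
x_u32 = reinterpret(x_signed, tl.uint32)    # wraps x_signed in a TensorWrapper
assert isinstance(x_u32, TensorWrapper)

# With this change, reinterpreting back to the base dtype unwraps the tensor
# instead of stacking another TensorWrapper.
assert reinterpret(x_u32, torch.int32) is x_signed
```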
Authored by Madeleine Thompson on 2022-01-05 15:27:17 -08:00, committed by GitHub.
Parent d8db0308cb, commit 0ab9d67bad.
12 changed files with 444 additions and 110 deletions.


@@ -109,6 +109,24 @@ std::string pow2_divisor(long N){
return "1";
}
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object& dtype) {
if (py::hasattr(dtype, "cache_key_part")) {
// Presumed to be a triton.language.dtype.
return std::string(py::str(py::getattr(dtype, "cache_key_part")));
} else {
// Remove 'torch.' prefix from repr of torch.dtype.
py::object repr = py::repr(dtype);
size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
}
return std::string(repr_ptr + 6, repr_len - 6);
}
}
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
cache_key += "1";
continue;
}
// long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){
// int32, uint32, int64, and uint64 have different kernels
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
cache_key += "int32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
}
else{
} else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
cache_key += "uint32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
cache_key += "int64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
if(overflow){
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
std::memcpy(&value, &uvalue, 8);
}
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
}
cache_key += "uint64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
if(!specialize)
continue;
@@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
py::object dtype = arg.attr("dtype");
py::object repr = py::repr(dtype);
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len);
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
@@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) {
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
@@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) {
.def("is_bf16", &ir::type::is_bf16_ty)
.def("is_fp32", &ir::type::is_fp32_ty)
.def("is_fp64", &ir::type::is_fp64_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
@@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) {
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_int64", &ir::builder::get_int64, ret::reference)
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference);


@@ -1,7 +1,7 @@
import copy
import itertools
import re
from typing import Optional
from typing import Optional, Union
import numpy as np
import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState
import triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
dtypes = int_dtypes + uint_dtypes + float_dtypes
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes:
if dtype_str in int_dtypes + uint_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
low = iinfo.min if low is None else max(low, iinfo.min)
high = iinfo.max if high is None else min(high, iinfo.max)
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
raise RuntimeError(f'Unknown dtype {dtype_str}')
def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
# For now, this always converts to a torch tensor, but when we add unsigned
# integers, it will also support TensorWrapper, since torch doesn't have
# unsigned support.
return torch.tensor(x, device=device)
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
t = x.dtype.name
if t in uint_dtypes:
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
return torch.tensor(x, device=device)
def torch_dtype_name(dtype) -> str:
if isinstance(dtype, triton.language.dtype):
return dtype.name
elif isinstance(dtype, torch.dtype):
# 'torch.int64' -> 'int64'
m = re.match(r'^torch\.(\w+)$', str(dtype))
return m.group(1)
else:
raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
def to_numpy(x):
if isinstance(x, torch.Tensor):
if isinstance(x, TensorWrapper):
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
elif isinstance(x, torch.Tensor):
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations.
narrower floating point types than numpy in mixed operations, and because
Triton follows C/C++ semantics around mixed signed/unsigned operations, and
numpy/pytorch do not.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
('float16', 'uint16'): np.float16,
('float16', 'uint32'): np.float16,
('float16', 'uint64'): np.float16,
('int8', 'uint8'): np.uint8,
('int8', 'uint16'): np.uint16,
('int8', 'uint32'): np.uint32,
('int8', 'uint64'): np.uint64,
('int16', 'uint16'): np.uint16,
('int16', 'uint32'): np.uint32,
('int16', 'uint64'): np.uint64,
('int32', 'uint32'): np.uint32,
('int32', 'uint64'): np.uint64,
('int64', 'uint64'): np.uint64,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)
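# Illustration only (not part of this diff): the override table above is needed
# because Triton follows C/C++ promotion for mixed signed/unsigned integer
# operands, while numpy widens to a larger signed type instead.
def _promotion_example():
    a, b = np.int32(-1), np.uint32(1)
    # numpy widens int32 + uint32 to int64.
    assert (a + b).dtype == np.int64
    # Triton computes int32 + uint32 in uint32, so the numpy reference result
    # is overridden to np.uint32 via the table above.
    assert _binary_op_dtype_override('int32', 'uint32') == np.uint32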
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'float16'),
('uint32', 'float32'),
('uint64', 'float16'),
('uint64', 'float32'),
('uint64', 'float64'),
]
# ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Not equal to tolerance'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
# reference result for //.
expr = 'x // y'
numpy_expr = '((x - np.fmod(x, y)) / y)'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
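# Illustration only (not part of this diff): why '//' needs the nonstandard
# reference expression above. numpy's floor division rounds toward -inf, while
# Triton's integer '%' (and therefore '//') truncates toward zero like C, which
# is what np.fmod computes.
def _floordiv_reference_example():
    x, y = np.int32(-7), np.int32(2)
    assert x // y == -4                       # numpy: floor toward -inf
    assert (x - np.fmod(x, y)) / y == -3.0    # reference for Triton's behavior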
# ---------------
# test bitwise ops
# ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['<<', '>>']
for dtype_x in int_dtypes + uint_dtypes
for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
dtype_z = f'uint{bw}'
numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
# ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
# ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes
] + [\
] + [
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', 'int32')
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'int32', mode), ('add', 'float32', mode),
('max', 'int32', mode), ('max', 'float32', mode),
('min', 'int32', mode), ('min', 'float32', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
]
)
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, axis", [
('float32', (1, 1024), 1)
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
pass
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
kernel[(1, )](x)
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x)
# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
assert ir_value_type == value_type
@pytest.mark.parametrize(
"value, overflow",
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
)
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)


@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG


@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr):
op = op.value
# print(op)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
self.shared_mem = shared_mem
self.num_warps = num_warps
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
return 'L'
if -2**31 <= obj < 2**31:
return 'i32'
elif 2**31 <= obj < 2**32:
return 'u32'
elif -2**63 <= obj < 2**63:
return 'i64'
elif 2**63 <= obj < 2**64:
return 'u64'
else:
raise ValueError(f'integer overflow representing {obj}')
if isinstance(obj, float):
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False
raise NotImplementedError(f'could not compute type name for {obj}')
@staticmethod
def _to_triton_ir(context, obj):
@@ -607,6 +616,10 @@ class Kernel:
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
'u8': _triton.ir.type.get_uint8,
'u16': _triton.ir.type.get_uint16,
'u32': _triton.ir.type.get_uint32,
'u64': _triton.ir.type.get_uint64,
}
# convert torch.Tensor to Triton IR pointers
if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:
def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')


@@ -9,9 +9,16 @@ def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
if x.__abs__() <= 2**31:
if -2**31 <= x < 2**31:
return builder.get_int32(x)
return builder.get_int64(x)
elif 2**31 <= x < 2**32:
return builder.get_uint32(x)
elif -2**63 <= x < 2**63:
return builder.get_int64(x)
elif 2**63 <= x < 2**64:
return builder.get_uint64(x)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return builder.get_float32(x)
elif isinstance(x, constexpr):
@@ -83,6 +90,14 @@ class dtype:
def __str__(self):
return self.name
@property
def cache_key_part(self) -> str:
"""See cache_key_part() in triton.cc."""
return self.name
def __repr__(self):
return f'triton.language.{self.name}'
class pointer_dtype:
def __init__(self, element_ty):
@@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
uint8 = dtype(ir.type.get_uint8)
uint16 = dtype(ir.type.get_uint16)
uint32 = dtype(ir.type.get_uint32)
uint64 = dtype(ir.type.get_uint64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16)
@@ -120,6 +139,10 @@ class block:
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_uint8(): return uint8
if ir_type.is_uint16(): return uint16
if ir_type.is_uint32(): return uint32
if ir_type.is_uint64(): return uint64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_bf16(): return bfloat16