uint8, uint16, uint32, and uint64 in kernels (#413)

A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
Madeleine Thompson
2022-01-05 15:27:17 -08:00
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
12 changed files with 444 additions and 110 deletions

View File

@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr):
op = op.value
# print(op)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
self.shared_mem = shared_mem
self.num_warps = num_warps
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
return 'L'
if -2**31 <= obj < 2**31:
return 'i32'
elif 2**31 <= obj < 2**32:
return 'u32'
elif -2**63 <= obj < 2**63:
return 'i64'
elif 2**63 <= obj < 2**64:
return 'u64'
else:
raise ValueError(f'integer overflow representing {obj}')
if isinstance(obj, float):
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False
raise NotImplementedError(f'could not compute type name for {obj}')
@staticmethod
def _to_triton_ir(context, obj):
@@ -607,6 +616,10 @@ class Kernel:
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
'u8': _triton.ir.type.get_uint8,
'u16': _triton.ir.type.get_uint16,
'u32': _triton.ir.type.get_uint32,
'u64': _triton.ir.type.get_uint64,
}
# convert torch.Tensor to Triton IR pointers
if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:
def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')