uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types. Also: - Add tests for the `//`, `<<`, and `>>` operators. - Change `TensorWrapper` to unwrap objects when the resulting object would be simpler. - Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
committed by
GitHub
parent
d8db0308cb
commit
0ab9d67bad
@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.core.constexpr):
|
||||
op = op.value
|
||||
# print(op)
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
@@ -503,6 +502,7 @@ class Binary:
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
|
||||
|
||||
class LoadedBinary:
|
||||
def __init__(self, device: int, bin: Binary):
|
||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
||||
@@ -571,24 +571,33 @@ class Kernel:
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
triton.language.uint8: 'u8',
|
||||
triton.language.uint16: 'u16',
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
}
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
return type_names[obj.dtype]
|
||||
if isinstance(obj, triton.language.core.constexpr):
|
||||
obj = obj.value
|
||||
if isinstance(obj, int):
|
||||
if abs(obj) <= 0xffffffff:
|
||||
return 'I'
|
||||
return 'L'
|
||||
if -2**31 <= obj < 2**31:
|
||||
return 'i32'
|
||||
elif 2**31 <= obj < 2**32:
|
||||
return 'u32'
|
||||
elif -2**63 <= obj < 2**63:
|
||||
return 'i64'
|
||||
elif 2**63 <= obj < 2**64:
|
||||
return 'u64'
|
||||
else:
|
||||
raise ValueError(f'integer overflow representing {obj}')
|
||||
if isinstance(obj, float):
|
||||
return 'f'
|
||||
if isinstance(obj, bool):
|
||||
return 'B'
|
||||
if isinstance(obj, str):
|
||||
return 'str'
|
||||
assert False
|
||||
|
||||
|
||||
raise NotImplementedError(f'could not compute type name for {obj}')
|
||||
|
||||
@staticmethod
|
||||
def _to_triton_ir(context, obj):
|
||||
@@ -607,6 +616,10 @@ class Kernel:
|
||||
'i16': _triton.ir.type.get_int16,
|
||||
'i32': _triton.ir.type.get_int32,
|
||||
'i64': _triton.ir.type.get_int64,
|
||||
'u8': _triton.ir.type.get_uint8,
|
||||
'u16': _triton.ir.type.get_uint16,
|
||||
'u32': _triton.ir.type.get_uint32,
|
||||
'u64': _triton.ir.type.get_uint64,
|
||||
}
|
||||
# convert torch.Tensor to Triton IR pointers
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
@@ -1165,4 +1178,15 @@ class TensorWrapper:
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
return TensorWrapper(tensor, dtype)
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
if dtype == tensor.base.dtype:
|
||||
# Reinterpreting to the original interpretation; return the base.
|
||||
return tensor.base
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
|
Reference in New Issue
Block a user