Files
triton/python/triton/language/core.py
Madeleine Thompson e575ae3443 [FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
2021-12-10 15:19:20 -08:00

860 lines
23 KiB
Python

import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps
# convert block/dtype to ir values
def _to_ir(x, builder):
    """
    Converts a Python-side value into what the IR builder functions expect.

    :param x: bool, int, float, constexpr, block, dtype, or any other object.
    :param builder: IR builder providing get_int1/get_int32/get_int64/get_float32.
    :return: an IR constant for Python scalars, the underlying handle for
        block/dtype objects, or :code:`x` itself for anything else.
    """
    # bool is a subclass of int, so it has to be tested first.
    if isinstance(x, bool):
        return builder.get_int1(x)
    if isinstance(x, int):
        # Signed 32-bit integers cover [-2**31, 2**31 - 1]; 2**31 itself
        # does not fit and must be emitted as a 64-bit constant.
        if -2**31 <= x < 2**31:
            return builder.get_int32(x)
        return builder.get_int64(x)
    if isinstance(x, float):
        return builder.get_float32(x)
    if isinstance(x, constexpr):
        # unwrap the compile-time value and convert it like a plain scalar
        return _to_ir(x.value, builder)
    if isinstance(x, block):
        return x.handle
    if isinstance(x, dtype):
        return x.handle(builder)
    return x
def _patch(fn):
    """
    Wraps a frontend function so it accepts and returns Python-side objects.

    The wrapper converts every positional and keyword argument with
    :code:`_to_ir` before calling :code:`fn`, then wraps returned
    :code:`ir.value` objects back into :code:`block` instances.

    :param fn: the frontend callable to wrap; its last positional argument
        must be the :code:`ir.builder`.
    :return: the wrapping function.
    """
    def _from_ir(x):
        # IR values become blocks; void-typed values map to None.
        if isinstance(x, ir.value):
            if x.type.is_void():
                return None
            return block(x)
        return x

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # By convention the builder is always the last positional argument.
        builder = args[-1]
        assert isinstance(builder, ir.builder)
        args = [_to_ir(x, builder) for x in args]
        kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
        ret = fn(*args, **kwargs)
        if isinstance(ret, tuple):
            # Return a real tuple: a lazy `map` object can only be consumed
            # once and supports neither indexing nor len().
            return tuple(_from_ir(r) for r in ret)
        return _from_ir(ret)
    return wrapper
# Replace every callable exposed by `frontend` with its `_patch`-wrapped
# version; non-callable attributes are left untouched.
for name in dir(frontend):
    fn = getattr(frontend, name)
    if not callable(fn):
        continue
    setattr(frontend, name, _patch(fn))
def builtin(fn):
    """
    Decorator for functions that need an IR builder.

    The decorated function raises :code:`ValueError` unless it is called
    with a non-None :code:`_builder` keyword argument.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit None are treated the same way.
        if kwargs.get('_builder') is None:
            raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)
    return wrapper
class dtype:
    """
    Lazily-built data type: stores a factory that creates the IR type from
    a context.
    """

    def __init__(self, init):
        # callable taking a context and returning the IR type
        self.init = init

    def handle(self, builder):
        """Builds the IR type using the context owned by :code:`builder`."""
        return self.init(builder.context)

    def __str__(self):
        return "dtype({})".format(self.init.__name__)
class pointer_dtype:
    """Data type describing a pointer to elements of another data type."""

    def __init__(self, element_ty):
        # data type of the pointed-to elements
        self.element_ty = element_ty

    def handle(self, builder):
        """Builds the IR pointer type for :code:`element_ty`."""
        pointee = self.element_ty.handle(builder)
        # NOTE(review): the literal 1 is presumably the address space -- confirm.
        return ir.type.make_ptr(pointee, 1)
# scalar types
# Each constant wraps the `ir.type` factory that `dtype.handle` later calls
# with the builder's context to create the actual IR type.
int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16)
float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64)
# pointer types
# pointer whose element type is int32
pi32_t = pointer_dtype(int32)
class block:
    """
    Python-side wrapper around an IR value handle.

    Exposes the handle's shape, number of elements and data type, and maps
    operators onto the matching :code:`frontend` functions. Every operator
    must be called with the :code:`_builder` keyword argument; this is
    enforced uniformly by :code:`@builtin`.
    """

    @staticmethod
    def _init_dtype(ir_type):
        """
        Maps an IR scalar or pointer type to its dtype wrapper.

        :raises ValueError: if the IR type is neither a known scalar type
            nor a pointer.
        """
        # primitive type
        if ir_type.is_int1():
            return int1
        if ir_type.is_int8():
            return int8
        if ir_type.is_int16():
            return int16
        if ir_type.is_int32():
            return int32
        if ir_type.is_int64():
            return int64
        if ir_type.is_fp8():
            return float8
        if ir_type.is_fp16():
            return float16
        if ir_type.is_bf16():
            return bfloat16
        if ir_type.is_fp32():
            return float32
        if ir_type.is_fp64():
            return float64
        # pointer type
        if ir_type.is_ptr():
            element_ty = block._init_dtype(ir_type.element)
            return pointer_dtype(element_ty)
        raise ValueError(f"Unsupported type {ir_type}")

    def __init__(self, handle):
        # IR handle
        self.handle = handle
        # Block shape; non-block handles are given shape (1, )
        self.shape = (1, )
        if self.handle.type.is_block():
            self.shape = self.handle.type.shape
        # Total number of elements (product of the shape)
        self.numel = 1
        for s in self.shape:
            self.numel *= s
        # Data-type wrapper
        self.dtype = block._init_dtype(self.handle.type.scalar)

    # arithmetic operators
    @builtin
    def __add__(self, other, _builder=None):
        return frontend.add(self, other, _builder)

    @builtin
    def __radd__(self, other, _builder=None):
        return self.__add__(other, _builder=_builder)

    @builtin
    def __sub__(self, other, _builder=None):
        return frontend.sub(self, other, _builder)

    @builtin
    def __rsub__(self, other, _builder=None):
        return frontend.sub(other, self, _builder)

    @builtin
    def __mul__(self, other, _builder=None):
        return frontend.mul(self, other, _builder)

    @builtin
    def __rmul__(self, other, _builder=None):
        return self.__mul__(other, _builder=_builder)

    @builtin
    def __truediv__(self, other, _builder=None):
        return frontend.truediv(self, other, _builder)

    @builtin
    def __rtruediv__(self, other, _builder=None):
        return frontend.truediv(other, self, _builder)

    @builtin
    def __floordiv__(self, other, _builder=None):
        return frontend.floordiv(self, other, _builder)

    @builtin
    def __mod__(self, other, _builder=None):
        return frontend.mod(self, other, _builder)

    # unary operators
    @builtin
    def __neg__(self, _builder=None):
        return frontend.minus(self, _builder)

    @builtin
    def __invert__(self, _builder=None):
        return frontend.invert(self, _builder)

    # bitwise operators
    @builtin
    def __and__(self, other, _builder=None):
        return frontend.and_(self, other, _builder)

    @builtin
    def __or__(self, other, _builder=None):
        return frontend.or_(self, other, _builder)

    @builtin
    def __xor__(self, other, _builder=None):
        return frontend.xor_(self, other, _builder)

    @builtin
    def __lshift__(self, other, _builder=None):
        return frontend.shl(self, other, _builder)

    @builtin
    def __rshift__(self, other, _builder=None):
        return frontend.lshr(self, other, _builder)

    # comparison operators
    # NOTE(review): __rgt__/__rge__/__rlt__/__rle__ are not reflected-operator
    # hooks of the Python data model; presumably they are looked up by name
    # by the code generator -- confirm against the caller.
    # >
    @builtin
    def __gt__(self, other, _builder=None):
        return frontend.greater_than(self, other, _builder)

    @builtin
    def __rgt__(self, other, _builder=None):
        return frontend.greater_than(other, self, _builder)

    # >=
    @builtin
    def __ge__(self, other, _builder=None):
        return frontend.greater_equal(self, other, _builder)

    @builtin
    def __rge__(self, other, _builder=None):
        return frontend.greater_equal(other, self, _builder)

    # <
    @builtin
    def __lt__(self, other, _builder=None):
        return frontend.less_than(self, other, _builder)

    @builtin
    def __rlt__(self, other, _builder=None):
        return frontend.less_than(other, self, _builder)

    # <=
    @builtin
    def __le__(self, other, _builder=None):
        return frontend.less_equal(self, other, _builder)

    @builtin
    def __rle__(self, other, _builder=None):
        return frontend.less_equal(other, self, _builder)

    # ==
    @builtin
    def __eq__(self, other, _builder=None):
        return frontend.equal(self, other, _builder)

    @builtin
    def __ne__(self, other, _builder=None):
        return frontend.not_equal(self, other, _builder)

    @builtin
    def __getitem__(self, slices, _builder=None):
        """
        Reshapes the block according to an index made of :code:`None` and
        full slices (:code:`:`).

        Each :code:`None` inserts a dimension of size 1 and each full slice
        keeps the next source dimension. Other index entries contribute no
        dimension to the result.
        """
        if isinstance(slices, slice):
            slices = [slices]
        src_shape = self.shape
        dst_shape = []
        curr = 0
        for sl in slices:
            # Identity test: `== None` would dispatch to the operand's own
            # __eq__, which is not a plain comparison for blocks/constexprs.
            if sl is None:
                dst_shape.append(1)
            elif sl == slice(None, None, None):
                dst_shape.append(src_shape[curr])
                curr += 1
        ret = frontend.reshape(self, dst_shape, _builder)
        return ret

    @builtin
    def to(self, dtype, bitcast=False, _builder=None):
        """
        Converts the block to :code:`dtype`.

        :param dtype: target data type wrapper.
        :param bitcast: when true, calls :code:`frontend.bitcast` instead of
            :code:`frontend.cast`.
        """
        dtype = dtype.handle(_builder)
        if bitcast:
            return frontend.bitcast(self, dtype, _builder)
        return frontend.cast(self, dtype, _builder)
# -----------------------
# constexpr
# -----------------------
class constexpr:
    """
    This class is used to store a value that is known at compile-time.

    Binary operators accept either another :code:`constexpr` or a plain
    Python value as the second operand, and return plain Python values.
    """

    def __init__(self, value):
        self.value = value

    @staticmethod
    def _unwrap(operand):
        # Operands exposing `.value` (e.g. another constexpr) are unwrapped
        # exactly as before; plain values such as int/float/None are used
        # as-is instead of raising AttributeError.
        return getattr(operand, "value", operand)

    # arithmetic operators
    def __add__(self, other):
        return self.value + self._unwrap(other)

    def __radd__(self, other):
        return self._unwrap(other) + self.value

    def __sub__(self, other):
        return self.value - self._unwrap(other)

    def __rsub__(self, other):
        return self._unwrap(other) - self.value

    def __mul__(self, other):
        return self.value * self._unwrap(other)

    def __rmul__(self, other):
        return self._unwrap(other) * self.value

    def __truediv__(self, other):
        return self.value / self._unwrap(other)

    def __rtruediv__(self, other):
        return self._unwrap(other) / self.value

    def __floordiv__(self, other):
        return self.value // self._unwrap(other)

    def __rfloordiv__(self, other):
        return self._unwrap(other) // self.value

    # comparison operators
    def __gt__(self, other):
        return self.value > self._unwrap(other)

    def __rgt__(self, other):
        return self._unwrap(other) > self.value

    def __ge__(self, other):
        return self.value >= self._unwrap(other)

    def __rge__(self, other):
        return self._unwrap(other) >= self.value

    def __lt__(self, other):
        return self.value < self._unwrap(other)

    def __rlt__(self, other):
        return self._unwrap(other) < self.value

    def __le__(self, other):
        return self.value <= self._unwrap(other)

    def __rle__(self, other):
        return self._unwrap(other) <= self.value

    def __eq__(self, other):
        return self.value == self._unwrap(other)

    def __ne__(self, other):
        return self.value != self._unwrap(other)

    def __bool__(self):
        return bool(self.value)

    def __call__(self, *args, **kwds):
        # forwards the call to the wrapped value
        return self.value(*args, **kwds)
# -----------------------
# SPMD Programming Model
# -----------------------
@builtin
def program_id(axis, _builder=None):
    """
    Returns the index of the running program instance on launch-grid axis :code:`axis`.

    :param axis: Axis of the 3D launch grid to query; must be 0, 1 or 2.
    :type axis: int
    """
    pid = frontend.program_id(axis, _builder)
    return pid
@builtin
def num_programs(axis, _builder=None):
    """
    Returns how many program instances were launched on launch-grid axis :code:`axis`.

    :param axis: Axis of the 3D launch grid to query; must be 0, 1 or 2.
    :type axis: int
    """
    count = frontend.num_programs(axis, _builder)
    return count
# -----------------------
# Block Initialization
# -----------------------
@builtin
def arange(start, end, _builder=None):
    """
    Returns contiguous values within the half-open interval [:code:`start`, :code:`end`).

    :param start: First value of the interval. Must be a power of two.
    :type start: int
    :param end: Exclusive end of the interval. Must be a power of two >= start.
    :type end: int
    """
    values = frontend.arange(start, end, _builder)
    return values
@builtin
def zeros(shape, dtype, _builder=None):
    """
    Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, ). Every
        element must be a :code:`constexpr` wrapping an int.
    :type shape: tuple of constexpr ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    :raises TypeError: if a shape element is not a :code:`constexpr` or does
        not wrap an int.
    """
    # Validate and unwrap in a single pass so that one-shot iterables work.
    dims = []
    for i, d in enumerate(shape):
        if not isinstance(d, constexpr):
            raise TypeError(f"Shape element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
        dims.append(d.value)
    return frontend.zeros(dims, dtype, _builder)
# -----------------------
# Shape Manipulation
# -----------------------
@builtin
def broadcast(input, other, _builder=None):
    """
    Attempts to bring two blocks to one common, compatible shape.

    :param input: First block.
    :type input: Block
    :param other: Second block.
    :type other: Block
    """
    result = frontend.broadcast(input, other, _builder)
    return result
@builtin
def broadcast_to(input, shape, _builder=None):
    """
    Attempts to broadcast a block to the requested :code:`shape`.

    :param input: Block to broadcast.
    :type input: Block
    :param shape: Target shape.
    :type shape: Tuple[int]
    """
    result = frontend.broadcast_to(input, shape, _builder)
    return result
@builtin
def cat(input, other, _builder=None):
    """
    Concatenates two blocks.

    :param input: First block.
    :type input:
    :param other: Second block.
    :type other:
    """
    joined = frontend.cat(input, other, _builder)
    return joined
@builtin
def reshape(input, shape, _builder=None):
    """
    Attempts to give a block a new shape.

    :param input: Block to reshape.
    :type input:
    :param shape: Target shape.
    :type shape: Tuple[int]
    """
    reshaped = frontend.reshape(input, shape, _builder)
    return reshaped
# -----------------------
# Linear Algebra
# -----------------------
@builtin
def dot(input, other, _builder=None):
    """
    Computes the matrix product of two blocks.

    Both blocks must be two-dimensional and their inner dimensions must be compatible.

    :param input: Left operand of the product.
    :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
    :param other: Right operand of the product.
    :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
    """
    product = frontend.dot(input, other, _builder)
    return product
# -----------------------
# Non-Atomic Memory Operations
# -----------------------
@builtin
def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
    """
    Loads, element-wise, a block of data from the memory locations given by :code:`pointer`.

    :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`,
    and :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.

    :param pointer: Pointers to the data to be loaded.
    :type pointer: Block of dtype=triton.PointerDType
    :param mask: Where mask[idx] is false, the data at :code:`pointer[idx]` is not loaded.
    :type mask: Block of triton.int1, optional
    :param other: Where mask[idx] is false, other[idx] is returned instead.
    :type other: Block, optional
    :param cache_modifier: Changes the cache option in nvidia ptx.
    :type cache_modifier: str, optional
    """
    loaded = frontend.load(pointer, mask, other, cache_modifier, _builder)
    return loaded
@builtin
def store(pointer, value, mask=None, _builder=None):
    """
    Writes the elements of :code:`value`, element-wise, to the memory locations given by :code:`pointer`.

    :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast
    to :code:`pointer.dtype.element_ty`.

    :param pointer: Memory locations that receive the elements of :code:`value`.
    :type pointer: Block of dtype=triton.PointerDType
    :param value: Block of elements to write.
    :type value: Block
    :param mask: Where mask[idx] is false, :code:`value[idx]` is not written to :code:`pointer[idx]`.
    :type mask: Block of triton.int1, optional
    """
    outcome = frontend.store(pointer, value, mask, _builder)
    return outcome
# -----------------------
# Atomic Memory Operations
# -----------------------
def _add_atomic_docstr(name):
    """Returns a decorator that sets the shared atomic-operation docstring, filled in with :code:`name`."""
    # Shared template; `{name}` is substituted with the operation name.
    template = (
        "\n"
        "Performs an atomic {name} at the memory location specified by :code:`pointer`.\n"
        "Return the data stored at :code:`pointer` before the atomic operation.\n"
        ":param pointer: The memory locations to compare-and-swap.\n"
        ":type pointer: Block of dtype=triton.PointerDType\n"
        ":param cmp: The values expected to be found in the atomic object\n"
        ":type cmp: Block of dtype=`pointer.dtype.element_ty`\n"
        ":param val: The values to copy in case the expected value matches the contained value.\n"
        ":type val: Block of dtype=`pointer.dtype.element_ty`\n"
    )

    def _decorator(func):
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_cas(pointer, cmp, val, _builder)
    return previous
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_xchg(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_add(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_max(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_min(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_and(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_or(pointer, val, mask, _builder)
    return previous
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, _builder=None):
    # The docstring is injected by the decorator; the builder must remain
    # the last positional argument of the frontend call.
    previous = frontend.atomic_xor(pointer, val, mask, _builder)
    return previous
# -----------------------
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, _builder=None):
    """
    Selects elements from :code:`x` or :code:`y` depending on :code:`condition`.

    Both :code:`x` and :code:`y` are always evaluated, whatever the value of
    :code:`condition`. To avoid unintended memory operations, use the
    :code:`mask` arguments of `triton.load` and `triton.store` instead.

    :code:`x` and :code:`y` are both broadcast to the shape of
    :code:`condition` and must have the same data type.

    :param condition: When True (nonzero), yield x, otherwise yield y.
    :type condition: Block of triton.bool
    :param x: Values selected at indices where condition is True.
    :param y: Values selected at indices where condition is False.
    """
    selected = frontend.where(condition, x, y, _builder)
    return selected
# -----------------------
# Math
# -----------------------
@builtin
def umulhi(x, y, _builder=None):
    # NOTE(review): presumably the high half of the unsigned product of
    # x and y -- confirm against the frontend implementation.
    result = frontend.umulhi(x, y, _builder)
    return result
def _add_math_1arg_docstr(name):
    """Returns a decorator that sets the shared one-argument math docstring, filled in with :code:`name`."""
    # Shared template; `{name}` is substituted with the function's math name.
    template = (
        "\n"
        "Computes the element-wise {name} of :code:`x`\n"
        ":param x: the input values\n"
        ":type x: Block\n"
    )

    def _decorator(func):
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
    # The docstring is injected by the decorator.
    result = frontend.exp(x, _builder)
    return result
@builtin
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
    # The docstring is injected by the decorator.
    result = frontend.log(x, _builder)
    return result
@builtin
@_add_math_1arg_docstr("cosine")
def cos(x, _builder=None):
    # The docstring is injected by the decorator.
    result = frontend.cos(x, _builder)
    return result
@builtin
@_add_math_1arg_docstr("sine")
def sin(x, _builder=None):
    # The docstring is injected by the decorator.
    result = frontend.sin(x, _builder)
    return result
@builtin
@_add_math_1arg_docstr("square root")
def sqrt(x, _builder=None):
    # The docstring is injected by the decorator.
    result = frontend.sqrt(x, _builder)
    return result
# -----------------------
# Reductions
# -----------------------
def _add_reduction_docstr(name):
    """Returns a decorator that sets the shared reduction docstring, filled in with :code:`name`."""
    # Shared template; `{name}` is substituted with the reduction's name.
    template = (
        "\n"
        "Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis`\n"
        ":param input: the input values\n"
        ":param axis: the dimension along which the reduction should be done\n"
    )

    def _decorator(func):
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
@builtin
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
    # The docstring is injected by the decorator.
    reduced = frontend.max(input, axis, _builder)
    return reduced
@builtin
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
    # The docstring is injected by the decorator.
    reduced = frontend.min(input, axis, _builder)
    return reduced
@builtin
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):
    # The docstring is injected by the decorator.
    reduced = frontend.sum(input, axis, _builder)
    return reduced
# -----------------------
# Internal for debugging
# -----------------------
@builtin
def debug_barrier(_builder=None):
    # Debugging helper that forwards to the frontend barrier.
    # NOTE(review): the exact synchronization scope is not visible here -- confirm.
    outcome = frontend.debug_barrier(_builder)
    return outcome
@builtin
def multiple_of(input, value, _builder=None):
    """
    Tells the compiler that every value in :code:`input` is a multiple of :code:`value`.
    """
    annotated = frontend.multiple_of(input, value, _builder)
    return annotated
@builtin
def max_contiguous(input, value, _builder=None):
    """
    Let the compiler know that the first `value` values in :code:`input` are contiguous.
    """
    # This function used to be defined twice with identical bodies; a single
    # definition is kept.
    return frontend.max_contiguous(input, value, _builder)
# -----------------------
# Standard library
# -----------------------
@triton.jit
def abs(x):
    # Element-wise absolute value: keeps x where x >= 0 and selects -x elsewhere.
    return where(x >= 0, x, -x)
@triton.jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`
    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # Adding div - 1 before the floor division rounds the quotient up
    # (for positive operands).
    return (x + div - 1) // div
@triton.jit
def minimum(x, y):
    """
    Computes the element-wise minimum of :code:`x` and :code:`y`.
    :param x: the first input block
    :type x: Block
    :param y: the second input block
    :type y: Block
    """
    # picks x where x < y, otherwise y
    return triton.language.where(x < y, x, y)
@triton.jit
def maximum(x, y):
    """
    Computes the element-wise maximum of :code:`x` and :code:`y`.
    :param x: the first input block
    :type x: Block
    :param y: the second input block
    :type y: Block
    """
    # picks x where x > y, otherwise y
    return triton.language.where(x > y, x, y)
@triton.jit
@_add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Computes 1 / (1 + exp(-x)); the docstring is supplied by the decorator above.
    return 1 / (1 + triton.language.exp(-x))
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x):
    # Softmax along axis 0; the docstring is supplied by the decorator above.
    # Subtract the maximum along axis 0 before exponentiating
    # (presumably for numerical stability).
    z = x - triton.language.max(x, 0)
    num = triton.language.exp(z)
    # normalize by the sum of the exponentials along axis 0
    den = triton.language.sum(num, 0)
    return num / den
@triton.jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`
    :param x: the input block
    :type x: Block
    """
    # A block exposes its element count directly as `numel`; it has no
    # `type` attribute, so `x.type.numel` could not be resolved.
    return triton.language.reshape(x, [x.numel])
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_j rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
    [4 , 5 , 6 , 7 ],
    [8 , 9 , 10, 11],
    [12, 13, 14, 15]]
    into
    [[0, 2, 4 , 6 ],
    [1, 3, 5 , 7 ],
    [8, 10, 12, 14],
    [9, 11, 13, 15]]

    :param i: row index in the original layout
    :param j: column index in the original layout
    :param size_i: number of rows of the matrix
    :param size_j: number of columns of the matrix
    :param size_g: number of rows per group
    :return: the pair (new_i, new_j) of transformed row and column indices
    """
    # "unrolled index in array"
    ij = i*size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    # (size_g is rebound here, so the lines below use the clamped value)
    size_g = minimum(size_i - off_i, size_g)
    # new row and column indices
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j