- Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. - Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes.
860 lines
23 KiB
Python
860 lines
23 KiB
Python
import triton
|
|
from triton._C.libtriton.triton import ir
|
|
from triton._C.libtriton.triton import frontend
|
|
from functools import wraps
|
|
|
|
|
|
# convert block/dtype to ir values
|
|
def _to_ir(x, builder):
|
|
if isinstance(x, bool):
|
|
return builder.get_int1(x)
|
|
elif isinstance(x, int):
|
|
if x.__abs__() <= 2**31:
|
|
return builder.get_int32(x)
|
|
return builder.get_int64(x)
|
|
elif isinstance(x, float):
|
|
return builder.get_float32(x)
|
|
elif isinstance(x, constexpr):
|
|
return _to_ir(x.value, builder)
|
|
elif isinstance(x, block):
|
|
return x.handle
|
|
elif isinstance(x, dtype):
|
|
return x.handle(builder)
|
|
return x
|
|
|
|
|
|
def _patch(fn):
|
|
def _from_ir(x):
|
|
if isinstance(x, ir.value):
|
|
if x.type.is_void():
|
|
return None
|
|
return block(x)
|
|
return x
|
|
|
|
def wrapper(*args, **kwargs):
|
|
builder = args[-1]
|
|
assert isinstance(builder, ir.builder)
|
|
args = [_to_ir(x, builder) for x in args]
|
|
# for i, arg in enumerate(args):
|
|
# if arg is None:
|
|
# raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}")
|
|
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
|
|
ret = fn(*args, **kwargs)
|
|
if isinstance(ret, tuple):
|
|
return map(_from_ir, ret)
|
|
return _from_ir(ret)
|
|
|
|
return wrapper
|
|
|
|
|
|
for name in dir(frontend):
|
|
fn = getattr(frontend, name)
|
|
if callable(fn):
|
|
setattr(frontend, name, _patch(fn))
|
|
|
|
|
|
def builtin(fn):
|
|
@wraps(fn)
|
|
def wrapper(*args, **kwargs):
|
|
if '_builder' not in kwargs or \
|
|
kwargs['_builder'] is None:
|
|
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
|
return fn(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
class dtype:
|
|
def __init__(self, init):
|
|
self.init = init
|
|
|
|
def handle(self, builder):
|
|
ctx = builder.context
|
|
return self.init(ctx)
|
|
|
|
def __str__(self):
|
|
return f"dtype({self.init.__name__})"
|
|
|
|
|
|
class pointer_dtype:
|
|
def __init__(self, element_ty):
|
|
self.element_ty = element_ty
|
|
|
|
def handle(self, builder):
|
|
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
|
|
|
# scalar types
|
|
int1 = dtype(ir.type.get_int1)
|
|
int8 = dtype(ir.type.get_int8)
|
|
int16 = dtype(ir.type.get_int16)
|
|
int32 = dtype(ir.type.get_int32)
|
|
int64 = dtype(ir.type.get_int64)
|
|
float8 = dtype(ir.type.get_fp8)
|
|
float16 = dtype(ir.type.get_fp16)
|
|
bfloat16 = dtype(ir.type.get_bf16)
|
|
float32 = dtype(ir.type.get_fp32)
|
|
float64 = dtype(ir.type.get_fp64)
|
|
# pointer types
|
|
pi32_t = pointer_dtype(int32)
|
|
|
|
|
|
class block:
|
|
@staticmethod
|
|
def _init_dtype(ir_type):
|
|
# primitive type
|
|
if ir_type.is_int1(): return int1
|
|
if ir_type.is_int8(): return int8
|
|
if ir_type.is_int16(): return int16
|
|
if ir_type.is_int32(): return int32
|
|
if ir_type.is_int64(): return int64
|
|
if ir_type.is_fp8(): return float8
|
|
if ir_type.is_fp16(): return float16
|
|
if ir_type.is_bf16(): return bfloat16
|
|
if ir_type.is_fp32(): return float32
|
|
if ir_type.is_fp64(): return float64
|
|
# pointer type
|
|
if ir_type.is_ptr():
|
|
element_ty = block._init_dtype(ir_type.element)
|
|
return pointer_dtype(element_ty)
|
|
raise ValueError(f"Unsupported type {ir_type}")
|
|
|
|
def __init__(self, handle):
|
|
# IR handle
|
|
self.handle = handle
|
|
# Block shape
|
|
self.shape = (1, )
|
|
if self.handle.type.is_block():
|
|
self.shape = self.handle.type.shape
|
|
self.numel = 1
|
|
for s in self.shape:
|
|
self.numel *= s
|
|
# Data-type wrapper
|
|
self.dtype = block._init_dtype(self.handle.type.scalar)
|
|
|
|
@builtin
|
|
def __add__(self, other, _builder=None):
|
|
return frontend.add(self, other, _builder)
|
|
|
|
def __radd__(self, other, _builder=None):
|
|
return self.__add__(other, _builder=_builder)
|
|
|
|
@builtin
|
|
def __sub__(self, other, _builder=None):
|
|
return frontend.sub(self, other, _builder)
|
|
|
|
def __rsub__(self, other, _builder=None):
|
|
return frontend.sub(other, self, _builder)
|
|
|
|
@builtin
|
|
def __mul__(self, other, _builder=None):
|
|
return frontend.mul(self, other, _builder)
|
|
|
|
def __rmul__(self, other, _builder=None):
|
|
return self.__mul__(other, _builder=_builder)
|
|
|
|
@builtin
|
|
def __truediv__(self, other, _builder=None):
|
|
return frontend.truediv(self, other, _builder)
|
|
|
|
def __rtruediv__(self, other, _builder=None):
|
|
return frontend.truediv(other, self, _builder)
|
|
|
|
@builtin
|
|
def __floordiv__(self, other, _builder=None):
|
|
return frontend.floordiv(self, other, _builder)
|
|
|
|
@builtin
|
|
def __mod__(self, other, _builder=None):
|
|
return frontend.mod(self, other, _builder)
|
|
|
|
# unary operators
|
|
@builtin
|
|
def __neg__(self, _builder=None):
|
|
return frontend.minus(self, _builder)
|
|
|
|
@builtin
|
|
def __invert__(self, _builder=None):
|
|
return frontend.invert(self, _builder)
|
|
|
|
# bitwise operators
|
|
|
|
@builtin
|
|
def __and__(self, other, _builder=None):
|
|
return frontend.and_(self, other, _builder)
|
|
|
|
@builtin
|
|
def __or__(self, other, _builder=None):
|
|
return frontend.or_(self, other, _builder)
|
|
|
|
@builtin
|
|
def __xor__(self, other, _builder=None):
|
|
return frontend.xor_(self, other, _builder)
|
|
|
|
@builtin
|
|
def __lshift__(self, other, _builder=None):
|
|
return frontend.shl(self, other, _builder)
|
|
|
|
@builtin
|
|
def __rshift__(self, other, _builder=None):
|
|
return frontend.lshr(self, other, _builder)
|
|
|
|
# comparison operators
|
|
|
|
# >
|
|
@builtin
|
|
def __gt__(self, other, _builder=None):
|
|
return frontend.greater_than(self, other, _builder)
|
|
|
|
@builtin
|
|
def __rgt__(self, other, _builder=None):
|
|
return frontend.greater_than(other, self, _builder)
|
|
|
|
# >=
|
|
@builtin
|
|
def __ge__(self, other, _builder=None):
|
|
return frontend.greater_equal(self, other, _builder)
|
|
|
|
def __rge__(self, other, _builder=None):
|
|
return frontend.greater_equal(other, self, _builder)
|
|
|
|
# <
|
|
@builtin
|
|
def __lt__(self, other, _builder=None):
|
|
return frontend.less_than(self, other, _builder)
|
|
|
|
@builtin
|
|
def __rlt__(self, other, _builder=None):
|
|
return frontend.less_than(other, self, _builder)
|
|
|
|
# <=
|
|
@builtin
|
|
def __le__(self, other, _builder=None):
|
|
return frontend.less_equal(self, other, _builder)
|
|
|
|
@builtin
|
|
def __rle__(self, other, _builder=None):
|
|
return frontend.less_equal(other, self, _builder)
|
|
|
|
# ==
|
|
@builtin
|
|
def __eq__(self, other, _builder=None):
|
|
return frontend.equal(self, other, _builder)
|
|
|
|
@builtin
|
|
def __ne__(self, other, _builder=None):
|
|
return frontend.not_equal(self, other, _builder)
|
|
|
|
@builtin
|
|
def __getitem__(self, slices, _builder=None):
|
|
if isinstance(slices, slice):
|
|
slices = [slices]
|
|
src_shape = self.shape
|
|
dst_shape = []
|
|
curr = 0
|
|
for sl in slices:
|
|
if sl == None:
|
|
dst_shape.append(1)
|
|
elif sl == slice(None, None, None):
|
|
dst_shape.append(src_shape[curr])
|
|
curr += 1
|
|
ret = frontend.reshape(self, dst_shape, _builder)
|
|
return ret
|
|
|
|
@builtin
|
|
def to(self, dtype, bitcast=False, _builder=None):
|
|
dtype = dtype.handle(_builder)
|
|
if bitcast:
|
|
return frontend.bitcast(self, dtype, _builder)
|
|
return frontend.cast(self, dtype, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# constexpr
|
|
# -----------------------
|
|
|
|
class constexpr:
|
|
"""
|
|
This class is used to store a value that is known at compile-time.
|
|
"""
|
|
|
|
def __init__(self, value):
|
|
self.value = value
|
|
|
|
def __add__(self, other):
|
|
return self.value + other.value
|
|
|
|
def __radd__(self, other):
|
|
return other.value + self.value
|
|
|
|
def __sub__(self, other):
|
|
return self.value - other.value
|
|
|
|
def __rsub__(self, other):
|
|
return other.value - self.value
|
|
|
|
def __mul__(self, other):
|
|
return self.value * other.value
|
|
|
|
def __rmul__(self, other):
|
|
return other.value * self.value
|
|
|
|
def __truediv__(self, other):
|
|
return self.value / other.value
|
|
|
|
def __rtruediv__(self, other):
|
|
return other.value / self.value
|
|
|
|
def __floordiv__(self, other):
|
|
return self.value // other.value
|
|
|
|
def __rfloordiv__(self, other):
|
|
return other.value // self.value
|
|
|
|
#
|
|
|
|
def __gt__(self, other):
|
|
return self.value > other.value
|
|
|
|
def __rgt__(self, other):
|
|
return other.value > self.value
|
|
|
|
def __ge__(self, other):
|
|
return self.value >= other.value
|
|
|
|
def __rge__(self, other):
|
|
return other.value >= self.value
|
|
|
|
def __lt__(self, other):
|
|
return self.value < other.value
|
|
|
|
def __rlt__(self, other):
|
|
return other.value < self.value
|
|
|
|
def __le__(self, other):
|
|
return self.value <= other.value
|
|
|
|
def __rle__(self, other):
|
|
return other.value <= self.value
|
|
|
|
def __eq__(self, other):
|
|
return self.value == other.value
|
|
|
|
def __ne__(self, other):
|
|
return self.value != other.value
|
|
|
|
def __bool__(self):
|
|
return bool(self.value)
|
|
|
|
def __call__(self, *args, **kwds):
|
|
return self.value(*args, **kwds)
|
|
|
|
# -----------------------
|
|
# SPMD Programming Model
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def program_id(axis, _builder=None):
|
|
"""
|
|
Returns the id of the current program instance along the given :code:`axis`.
|
|
|
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
|
:type axis: int
|
|
"""
|
|
# if axis == -1:
|
|
# pid0 = frontend.program_id(0, _builder)
|
|
# pid1 = frontend.program_id(1, _builder)
|
|
# pid2 = frontend.program_id(2, _builder)
|
|
# npg0 = frontend.num_programs(0, _builder)
|
|
# npg1 = frontend.num_programs(0, _builder)
|
|
# return pid0 + pid1*npg0 + pid2*npg0*npg1
|
|
return frontend.program_id(axis, _builder)
|
|
|
|
|
|
@builtin
|
|
def num_programs(axis, _builder=None):
|
|
"""
|
|
Returns the number of program instances launched along the given :code:`axis`.
|
|
|
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
|
:type axis: int
|
|
"""
|
|
return frontend.num_programs(axis, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Block Initialization
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def arange(start, end, _builder=None):
|
|
"""
|
|
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
|
|
|
|
:param start: Start of the interval. Must be a power of two.
|
|
:type start: int
|
|
:param stop: End of the interval. Must be a power of two >= start.
|
|
:type stop: int
|
|
"""
|
|
return frontend.arange(start, end, _builder)
|
|
|
|
|
|
@builtin
|
|
def zeros(shape, dtype, _builder=None):
|
|
"""
|
|
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
|
|
|
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
|
:type shape: tuple of ints
|
|
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
|
|
:type dtype: DType
|
|
"""
|
|
for i, d in enumerate(shape):
|
|
if not isinstance(d, constexpr):
|
|
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
|
if not isinstance(d.value, int):
|
|
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
|
shape = [x.value for x in shape]
|
|
return frontend.zeros(shape, dtype, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Shape Manipulation
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def broadcast(input, other, _builder=None):
|
|
"""
|
|
Tries to broadcast the two given blocks to a common compatible shape.
|
|
|
|
:param input: The first input block.
|
|
:type input: Block
|
|
:param other: The second input block.
|
|
:type other: Block
|
|
"""
|
|
return frontend.broadcast(input, other, _builder)
|
|
|
|
|
|
@builtin
|
|
def broadcast_to(input, shape, _builder=None):
|
|
"""
|
|
Tries to broadcast the given block to a new :code:`shape`.
|
|
|
|
:param input: The input block.
|
|
:type input: Block
|
|
:param shape: The desired shape.
|
|
:type shape: Tuple[int]
|
|
"""
|
|
return frontend.broadcast_to(input, shape, _builder)
|
|
|
|
@builtin
|
|
def cat(input, other, _builder=None):
|
|
"""
|
|
Concatenate the given blocks
|
|
|
|
:param input: The first input block.
|
|
:type input:
|
|
:param other: The second input block.
|
|
:type other:
|
|
"""
|
|
return frontend.cat(input, other, _builder)
|
|
|
|
|
|
@builtin
|
|
def reshape(input, shape, _builder=None):
|
|
"""
|
|
Tries to reshape the given block to a new shape.
|
|
|
|
:param input: The input block.
|
|
:type input:
|
|
:param shape: The desired shape.
|
|
:type shape: Tuple[int]
|
|
|
|
"""
|
|
return frontend.reshape(input, shape, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Linear Algebra
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def dot(input, other, _builder=None):
|
|
"""
|
|
Returns the matrix product of two blocks.
|
|
|
|
The two blocks must be two dimensionals and have compatible inner dimensions.
|
|
|
|
:param input: The first block to be multiplied.
|
|
:type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
|
:param other: The second block to be multiplied.
|
|
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
|
"""
|
|
return frontend.dot(input, other, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Non-Atomic Memory Operations
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
|
|
"""
|
|
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
|
|
|
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
|
|
|
|
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
|
|
|
|
:param pointer: Pointers to the data to be loaded.
|
|
:type pointer: Block of dtype=triton.PointerDType
|
|
:param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
|
|
:type mask: Block of triton.int1, optional
|
|
:param other: if mask[idx] is false, return other[idx]
|
|
:type other: Block, optional
|
|
:param cache_modifier: changes cache option in nvidia ptx
|
|
'type cache_modifier: str, optional
|
|
"""
|
|
return frontend.load(pointer, mask, other, cache_modifier, _builder)
|
|
|
|
|
|
@builtin
|
|
def store(pointer, value, mask=None, _builder=None):
|
|
"""
|
|
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
|
|
|
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
|
|
|
|
:param pointer: The memory locations where the elements of :code:`value` are stored.
|
|
:type pointer: Block of dtype=triton.PointerDType
|
|
:param value: The block of elements to be stored.
|
|
:type value: Block
|
|
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
|
|
:type mask: Block of triton.int1, optional
|
|
"""
|
|
return frontend.store(pointer, value, mask, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Atomic Memory Operations
|
|
# -----------------------
|
|
|
|
def _add_atomic_docstr(name):
|
|
|
|
def _decorator(func):
|
|
docstr = """
|
|
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
|
|
|
Return the data stored at :code:`pointer` before the atomic operation.
|
|
|
|
:param pointer: The memory locations to compare-and-swap.
|
|
:type pointer: Block of dtype=triton.PointerDType
|
|
:param cmp: The values expected to be found in the atomic object
|
|
:type cmp: Block of dtype=`pointer.dtype.element_ty`
|
|
:param val: The values to copy in case the expected value matches the contained value.
|
|
:type val: Block of dtype=`pointer.dtype.element_ty`
|
|
"""
|
|
func.__doc__ = docstr.format(name=name)
|
|
return func
|
|
|
|
return _decorator
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("compare-and-swap")
|
|
def atomic_cas(pointer, cmp, val, _builder=None):
|
|
return frontend.atomic_cas(pointer, cmp, val, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("exchange")
|
|
def atomic_xchg(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_xchg(pointer, val, mask, _builder)
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("add")
|
|
def atomic_add(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_add(pointer, val, mask, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("max")
|
|
def atomic_max(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_max(pointer, val, mask, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("min")
|
|
def atomic_min(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_min(pointer, val, mask, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("logical and")
|
|
def atomic_and(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_and(pointer, val, mask, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("logical or")
|
|
def atomic_or(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_or(pointer, val, mask, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_atomic_docstr("logical xor")
|
|
def atomic_xor(pointer, val, mask=None, _builder=None):
|
|
return frontend.atomic_xor(pointer, val, mask, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Conditioning
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def where(condition, x, y, _builder=None):
|
|
"""
|
|
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
|
|
|
|
Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
|
|
|
|
If you want to avoid unintented memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
|
|
|
|
The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
|
|
:code:`x` and :code:`y` must have the data type.
|
|
|
|
:param condition: When True (nonzero), yield x, otherwise yield y.
|
|
:type condition: Block of triton.bool
|
|
:param x: values selected at indices where condition is True.
|
|
:param y: values selected at indices where condition is False.
|
|
"""
|
|
return frontend.where(condition, x, y, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Math
|
|
# -----------------------
|
|
|
|
@builtin
|
|
def umulhi(x, y, _builder=None):
|
|
return frontend.umulhi(x, y, _builder)
|
|
|
|
def _add_math_1arg_docstr(name):
|
|
|
|
def _decorator(func):
|
|
docstr = """
|
|
Computes the element-wise {name} of :code:`x`
|
|
|
|
:param x: the input values
|
|
:type x: Block
|
|
"""
|
|
func.__doc__ = docstr.format(name=name)
|
|
return func
|
|
|
|
return _decorator
|
|
|
|
@builtin
|
|
@_add_math_1arg_docstr("exponential")
|
|
def exp(x, _builder=None):
|
|
return frontend.exp(x, _builder)
|
|
|
|
@builtin
|
|
@_add_math_1arg_docstr("natural logarithm")
|
|
def log(x, _builder=None):
|
|
return frontend.log(x, _builder)
|
|
|
|
@builtin
|
|
@_add_math_1arg_docstr("cosine")
|
|
def cos(x, _builder=None):
|
|
return frontend.cos(x, _builder)
|
|
|
|
@builtin
|
|
@_add_math_1arg_docstr("sine")
|
|
def sin(x, _builder=None):
|
|
return frontend.sin(x, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_math_1arg_docstr("square root")
|
|
def sqrt(x, _builder=None):
|
|
return frontend.sqrt(x, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Reductions
|
|
# -----------------------
|
|
|
|
def _add_reduction_docstr(name):
|
|
|
|
def _decorator(func):
|
|
docstr = """
|
|
Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis`
|
|
|
|
:param input: the input values
|
|
:param axis: the dimension along which the reduction should be done
|
|
"""
|
|
func.__doc__ = docstr.format(name=name)
|
|
return func
|
|
|
|
return _decorator
|
|
|
|
@builtin
|
|
@_add_reduction_docstr("maximum")
|
|
def max(input, axis, _builder=None):
|
|
return frontend.max(input, axis, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_reduction_docstr("minimum")
|
|
def min(input, axis, _builder=None):
|
|
return frontend.min(input, axis, _builder)
|
|
|
|
|
|
@builtin
|
|
@_add_reduction_docstr("sum")
|
|
def sum(input, axis, _builder=None):
|
|
return frontend.sum(input, axis, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Internal for debugging
|
|
# -----------------------
|
|
|
|
|
|
@builtin
|
|
def debug_barrier(_builder=None):
|
|
return frontend.debug_barrier(_builder)
|
|
|
|
|
|
@builtin
|
|
def multiple_of(input, value, _builder=None):
|
|
"""
|
|
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
|
|
"""
|
|
return frontend.multiple_of(input, value, _builder)
|
|
|
|
|
|
@builtin
|
|
def max_contiguous(input, value, _builder=None):
|
|
"""
|
|
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
|
|
"""
|
|
return frontend.max_contiguous(input, value, _builder)
|
|
|
|
|
|
@builtin
|
|
def max_contiguous(input, value, _builder=None):
|
|
"""
|
|
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
|
|
"""
|
|
return frontend.max_contiguous(input, value, _builder)
|
|
|
|
|
|
# -----------------------
|
|
# Standard library
|
|
# -----------------------
|
|
|
|
@triton.jit
|
|
def abs(x):
|
|
return where(x >= 0, x, -x)
|
|
|
|
@triton.jit
|
|
def cdiv(x, div):
|
|
"""
|
|
Computes the ceiling division of :code:`x` by :code:`div`
|
|
|
|
:param x: the input number
|
|
:type input: Block
|
|
:param div: the divisor
|
|
:param div: Block
|
|
"""
|
|
return (x + div - 1) // div
|
|
|
|
|
|
@triton.jit
|
|
def minimum(x, y):
|
|
"""
|
|
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
|
|
|
:param input: the first input block
|
|
:type input: Block
|
|
:param other: the second input block
|
|
:type other: Block
|
|
"""
|
|
return triton.language.where(x < y, x, y)
|
|
|
|
|
|
@triton.jit
|
|
def maximum(x, y):
|
|
"""
|
|
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
|
|
|
:param input: the first input block
|
|
:type input: Block
|
|
:param other: the second input block
|
|
:type other: Block
|
|
"""
|
|
return triton.language.where(x > y, x, y)
|
|
|
|
|
|
@triton.jit
|
|
@_add_math_1arg_docstr("sigmoid")
|
|
def sigmoid(x):
|
|
return 1 / (1 + triton.language.exp(-x))
|
|
|
|
|
|
@triton.jit
|
|
@_add_math_1arg_docstr("softmax")
|
|
def softmax(x):
|
|
z = x - triton.language.max(x, 0)
|
|
num = triton.language.exp(z)
|
|
den = triton.language.sum(num, 0)
|
|
return num / den
|
|
|
|
|
|
@triton.jit
|
|
def ravel(x):
|
|
"""
|
|
Returns a contiguous flattened view of :code:`x`
|
|
|
|
:param x: the input block
|
|
:type x: Block
|
|
"""
|
|
return triton.language.reshape(x, [x.type.numel])
|
|
|
|
@triton.jit
|
|
def swizzle2d(i, j, size_i, size_j, size_g):
|
|
"""
|
|
transformes indices of a row-major size_i*size_j matrix into those
|
|
of one where indices are row major for each group of size_j rows.
|
|
For example, for size_i = size_j = 4 and size_g = 2, it will transform
|
|
[[0 , 1 , 2 , 3 ],
|
|
[4 , 5 , 6 , 7 ],
|
|
[8 , 9 , 10, 11],
|
|
[12, 13, 14, 15]]
|
|
into
|
|
[[0, 2, 4 , 6 ],
|
|
[1, 3, 5 , 7 ],
|
|
[8, 10, 12, 14],
|
|
[9, 11, 13, 15]]
|
|
"""
|
|
# "unrolled index in array"
|
|
ij = i*size_j + j
|
|
# number of elements in `size_g` groups
|
|
# of `size_j` columns
|
|
size_gj = size_g * size_j
|
|
# index of the group in which (i,j) is
|
|
group_id = ij // size_gj
|
|
# row-index of the first element of this group
|
|
off_i = group_id * size_g
|
|
# last group may have fewer rows
|
|
size_g = minimum(size_i - off_i, size_g)
|
|
# new row and column indices
|
|
new_i = off_i + (ij % size_g)
|
|
new_j = (ij % size_gj) // size_g
|
|
return new_i, new_j
|