[DOCS] Various improvements (#224)

- Added docstr for autotune, Config, heuristics
- Added docstr for atomics
- Hiding internal _builder argument used for built-in language primitives
- Re-factor docstr to use common templates between similar functions.
This commit is contained in:
Philippe Tillet
2021-08-18 11:15:53 -07:00
committed by GitHub
parent 226fde6ea1
commit f26a48a3b4
6 changed files with 275 additions and 222 deletions

View File

@@ -24,11 +24,20 @@
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
def process_sig(app, what, name, obj, options, signature, return_annotation):
    """Hide the internal ``_builder`` argument from documented signatures.

    Handler for the Sphinx ``autodoc-process-signature`` event. When the
    signature string mentions ``_builder``, everything from that point on is
    dropped and the parameter list is closed again.

    :param signature: the signature string produced by autodoc, e.g.
        ``"(axis, _builder=None)"``; may be ``None`` or empty.
    :param return_annotation: passed through unchanged.
    :return: the ``(signature, return_annotation)`` pair autodoc should use.
    """
    if signature and '_builder' in signature:
        # Cutting at `_builder` leaves the separator that preceded it
        # ("(axis, "); strip it so the result reads "(axis)" and not "(axis, )".
        visible = signature.split('_builder')[0].rstrip(', ')
        signature = visible + ")"
    return (signature, return_annotation)
def setup(app): def setup(app):
"""Customize function args retrieving to get args under decorator.""" """Customize function args retrieving to get args under decorator."""
import sphinx import sphinx
import triton import triton
app.connect("autodoc-process-signature", process_sig)
def forward_jit_fn(func): def forward_jit_fn(func):
old = func old = func
@@ -39,6 +48,7 @@ def setup(app):
return wrapped return wrapped
old_documenter = sphinx.ext.autosummary.get_documenter old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent): def documenter(app, obj, parent):

View File

@@ -98,6 +98,18 @@ Reduction Ops
min min
sum sum
Atomic Ops
---------------
.. autosummary::
:toctree: generated
:nosignatures:
atomic_cas
atomic_add
atomic_max
atomic_min
Comparison ops Comparison ops
--------------- ---------------

View File

@@ -8,3 +8,6 @@ triton
:nosignatures: :nosignatures:
jit jit
autotune
heuristics
Config

View File

@@ -193,11 +193,11 @@ class CodeGenerator(ast.NodeVisitor):
kws = dict() kws = dict()
if self.is_triton_object(lhs): if self.is_triton_object(lhs):
kws['builder'] = self.builder kws['_builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws) ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented: if ret is NotImplemented:
if self.is_triton_object(rhs): if self.is_triton_object(rhs):
kws['builder'] = self.builder kws['_builder'] = self.builder
fn = fn[:2] + 'r' + fn[2:] fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws) ret = getattr(rhs, fn)(lhs, **kws)
return ret return ret
@@ -260,10 +260,10 @@ class CodeGenerator(ast.NodeVisitor):
ast.IsNot: '__ne__', ast.IsNot: '__ne__',
}[type(node.ops[0])] }[type(node.ops[0])]
if self.is_triton_object(lhs): if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, builder=self.builder) return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_object(rhs): elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:] fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, builder=self.builder) return getattr(rhs, fn)(lhs, _builder=self.builder)
else: else:
return getattr(lhs, fn)(rhs) return getattr(lhs, fn)(rhs)
@@ -275,7 +275,7 @@ class CodeGenerator(ast.NodeVisitor):
ast.Invert: '__invert__', ast.Invert: '__invert__',
}[type(node.op)] }[type(node.op)]
if self.is_triton_object(op): if self.is_triton_object(op):
return getattr(op, fn)(builder=self.builder) return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)() return getattr(op, fn)()
def visit_While(self, node): def visit_While(self, node):
@@ -308,7 +308,7 @@ class CodeGenerator(ast.NodeVisitor):
lhs = self.visit(node.value) lhs = self.visit(node.value)
slices = self.visit(node.slice) slices = self.visit(node.slice)
if self.is_triton_object(lhs): if self.is_triton_object(lhs):
return lhs.__getitem__(slices, builder=self.builder) return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices] return lhs[slices]
def visit_ExtSlice(self, node): def visit_ExtSlice(self, node):
@@ -331,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\ self.visit(pos_cond_node),\
self.visit(neg_cond_node),\ self.visit(neg_cond_node),\
builder=self.builder) _builder=self.builder)
#cond_node = neg_cond_node #cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation # code generation
@@ -385,7 +385,7 @@ class CodeGenerator(ast.NodeVisitor):
return fn(*args, generator=self, **kws) return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language: sys.modules[fn.__module__] is triton.language:
return fn(*args, builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws) return fn(*args, **kws)
def visit_Num(self, node): def visit_Num(self, node):
@@ -714,6 +714,19 @@ class JITFunction:
class Config: class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
"""
def __init__(self, meta, num_warps=4, num_stages=2): def __init__(self, meta, num_warps=4, num_stages=2):
self.meta = meta self.meta = meta
self.num_warps = num_warps self.num_warps = num_warps
@@ -721,6 +734,35 @@ class Config:
def autotune(configs, key, reset_to_zero=None): def autotune(configs, key, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn): def decorator(fn):
def wrapper(kernel): def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero) return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
@@ -732,6 +774,23 @@ def autotune(configs, key, reset_to_zero=None):
def heuristics(values): def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn): def decorator(fn):
def wrapper(kernel): def wrapper(kernel):
def fun(*args, **meta): def fun(*args, **meta):
@@ -767,6 +826,8 @@ def jit(fn):
return JITFunction(fn) return JITFunction(fn)
######
def cdiv(x, y): def cdiv(x, y):
return (x + y - 1) // y return (x + y - 1) // y

View File

@@ -49,16 +49,11 @@ for name in dir(frontend):
def builtin(fn): def builtin(fn):
@wraps(fn) @wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if 'builder' not in kwargs or \ if '_builder' not in kwargs or \
kwargs['builder'] is None: kwargs['_builder'] is None:
raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?") raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs) return fn(*args, **kwargs)
if wrapper.__doc__:
wrapper.__doc__ += """\
:param builder: IR builder to generate code into
:type builder: triton.ir.builder, optional from within JIT'ed functions
"""
return wrapper return wrapper
@@ -124,120 +119,120 @@ class block:
self.dtype = block._init_dtype(self.handle.type.scalar) self.dtype = block._init_dtype(self.handle.type.scalar)
@builtin @builtin
def __add__(self, other, builder=None): def __add__(self, other, _builder=None):
return frontend.add(self, other, builder) return frontend.add(self, other, _builder)
def __radd__(self, other, builder=None): def __radd__(self, other, _builder=None):
return self.__add__(other, builder=builder) return self.__add__(other, _builder=_builder)
@builtin @builtin
def __sub__(self, other, builder=None): def __sub__(self, other, _builder=None):
return frontend.sub(self, other, builder) return frontend.sub(self, other, _builder)
def __rsub__(self, other, builder=None): def __rsub__(self, other, _builder=None):
return frontend.sub(other, self, builder) return frontend.sub(other, self, _builder)
@builtin @builtin
def __mul__(self, other, builder=None): def __mul__(self, other, _builder=None):
return frontend.mul(self, other, builder) return frontend.mul(self, other, _builder)
def __rmul__(self, other, builder=None): def __rmul__(self, other, _builder=None):
return self.__mul__(other, builder=builder) return self.__mul__(other, _builder=_builder)
@builtin @builtin
def __truediv__(self, other, builder=None): def __truediv__(self, other, _builder=None):
return frontend.truediv(self, other, builder) return frontend.truediv(self, other, _builder)
def __rtruediv__(self, other, builder=None): def __rtruediv__(self, other, _builder=None):
return frontend.truediv(other, self, builder) return frontend.truediv(other, self, _builder)
@builtin @builtin
def __floordiv__(self, other, builder=None): def __floordiv__(self, other, _builder=None):
return frontend.floordiv(self, other, builder) return frontend.floordiv(self, other, _builder)
@builtin @builtin
def __mod__(self, other, builder=None): def __mod__(self, other, _builder=None):
return frontend.mod(self, other, builder) return frontend.mod(self, other, _builder)
# unary operators # unary operators
@builtin @builtin
def __neg__(self, builder=None): def __neg__(self, _builder=None):
return frontend.minus(self, builder) return frontend.minus(self, _builder)
@builtin @builtin
def __invert__(self, builder=None): def __invert__(self, _builder=None):
return frontend.invert(self, builder) return frontend.invert(self, _builder)
# bitwise operators # bitwise operators
@builtin @builtin
def __and__(self, other, builder=None): def __and__(self, other, _builder=None):
return frontend.and_(self, other, builder) return frontend.and_(self, other, _builder)
@builtin @builtin
def __or__(self, other, builder=None): def __or__(self, other, _builder=None):
return frontend.or_(self, other, builder) return frontend.or_(self, other, _builder)
@builtin @builtin
def __xor__(self, other, builder=None): def __xor__(self, other, _builder=None):
return frontend.xor_(self, other, builder) return frontend.xor_(self, other, _builder)
@builtin @builtin
def __lshift__(self, other, builder=None): def __lshift__(self, other, _builder=None):
return frontend.shl(self, other, builder) return frontend.shl(self, other, _builder)
@builtin @builtin
def __rshift__(self, other, builder=None): def __rshift__(self, other, _builder=None):
return frontend.lshr(self, other, builder) return frontend.lshr(self, other, _builder)
# comparison operators # comparison operators
# > # >
@builtin @builtin
def __gt__(self, other, builder=None): def __gt__(self, other, _builder=None):
return frontend.greater_than(self, other, builder) return frontend.greater_than(self, other, _builder)
@builtin @builtin
def __rgt__(self, other, builder=None): def __rgt__(self, other, _builder=None):
return frontend.greater_than(other, self, builder) return frontend.greater_than(other, self, _builder)
# >= # >=
@builtin @builtin
def __ge__(self, other, builder=None): def __ge__(self, other, _builder=None):
return frontend.greater_equal(self, other, builder) return frontend.greater_equal(self, other, _builder)
def __rge__(self, other, builder=None): def __rge__(self, other, _builder=None):
return frontend.greater_equal(other, self, builder) return frontend.greater_equal(other, self, _builder)
# < # <
@builtin @builtin
def __lt__(self, other, builder=None): def __lt__(self, other, _builder=None):
return frontend.less_than(self, other, builder) return frontend.less_than(self, other, _builder)
@builtin @builtin
def __rlt__(self, other, builder=None): def __rlt__(self, other, _builder=None):
return frontend.less_than(other, self, builder) return frontend.less_than(other, self, _builder)
# <= # <=
@builtin @builtin
def __le__(self, other, builder=None): def __le__(self, other, _builder=None):
return frontend.less_equal(self, other, builder) return frontend.less_equal(self, other, _builder)
@builtin @builtin
def __rle__(self, other, builder=None): def __rle__(self, other, _builder=None):
return frontend.less_equal(other, self, builder) return frontend.less_equal(other, self, _builder)
# == # ==
@builtin @builtin
def __eq__(self, other, builder=None): def __eq__(self, other, _builder=None):
return frontend.equal(self, other, builder) return frontend.equal(self, other, _builder)
@builtin @builtin
def __ne__(self, other, builder=None): def __ne__(self, other, _builder=None):
return frontend.not_equal(self, other, builder) return frontend.not_equal(self, other, _builder)
@builtin @builtin
def __getitem__(self, slices, builder=None): def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice): if isinstance(slices, slice):
slices = [slices] slices = [slices]
src_shape = self.shape src_shape = self.shape
@@ -249,15 +244,15 @@ class block:
elif sl == slice(None, None, None): elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr]) dst_shape.append(src_shape[curr])
curr += 1 curr += 1
ret = frontend.reshape(self, dst_shape, builder) ret = frontend.reshape(self, dst_shape, _builder)
return ret return ret
@builtin @builtin
def to(self, dtype, bitcast=False, builder=None): def to(self, dtype, bitcast=False, _builder=None):
dtype = dtype.handle(builder) dtype = dtype.handle(_builder)
if bitcast: if bitcast:
return frontend.bitcast(self, dtype, builder) return frontend.bitcast(self, dtype, _builder)
return frontend.cast(self, dtype, builder) return frontend.cast(self, dtype, _builder)
# ----------------------- # -----------------------
@@ -266,25 +261,25 @@ class block:
@builtin @builtin
def program_id(axis, builder=None): def program_id(axis, _builder=None):
""" """
Returns the id of the current program instance along the given :code:`axis`. Returns the id of the current program instance along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int :type axis: int
""" """
return frontend.program_id(axis, builder) return frontend.program_id(axis, _builder)
@builtin @builtin
def num_programs(axis, builder=None): def num_programs(axis, _builder=None):
""" """
Returns the number of program instances launched along the given :code:`axis`. Returns the number of program instances launched along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int :type axis: int
""" """
return frontend.num_programs(axis, builder) return frontend.num_programs(axis, _builder)
# ----------------------- # -----------------------
@@ -293,7 +288,7 @@ def num_programs(axis, builder=None):
@builtin @builtin
def arange(start, end, builder=None): def arange(start, end, _builder=None):
""" """
Returns contiguous values within the half-open interval [:code:`start`, :code:`end`). Returns contiguous values within the half-open interval [:code:`start`, :code:`end`).
@@ -302,11 +297,11 @@ def arange(start, end, builder=None):
:param end: End of the interval. Must be a power of two >= start. :param end: End of the interval. Must be a power of two >= start.
:type end: int :type end: int
""" """
return frontend.arange(start, end, builder) return frontend.arange(start, end, _builder)
@builtin @builtin
def zeros(shape, dtype, builder=None): def zeros(shape, dtype, _builder=None):
""" """
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
@@ -316,7 +311,7 @@ def zeros(shape, dtype, builder=None):
:type dtype: DType :type dtype: DType
""" """
shape = [int(x.handle) if isinstance(x, block) else x for x in shape] shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
return frontend.zeros(shape, dtype, builder) return frontend.zeros(shape, dtype, _builder)
# ----------------------- # -----------------------
@@ -325,7 +320,7 @@ def zeros(shape, dtype, builder=None):
@builtin @builtin
def broadcast(input, other, builder=None): def broadcast(input, other, _builder=None):
""" """
Tries to broadcast the two given blocks to a common compatible shape. Tries to broadcast the two given blocks to a common compatible shape.
@@ -334,11 +329,11 @@ def broadcast(input, other, builder=None):
:param other: The second input block. :param other: The second input block.
:type other: Block :type other: Block
""" """
return frontend.broadcast(input, other, builder) return frontend.broadcast(input, other, _builder)
@builtin @builtin
def broadcast_to(input, shape, builder=None): def broadcast_to(input, shape, _builder=None):
""" """
Tries to broadcast the given block to a new :code:`shape`. Tries to broadcast the given block to a new :code:`shape`.
@@ -347,11 +342,11 @@ def broadcast_to(input, shape, builder=None):
:param shape: The desired shape. :param shape: The desired shape.
:type shape: Tuple[int] :type shape: Tuple[int]
""" """
return frontend.broadcast_to(input, shape, builder) return frontend.broadcast_to(input, shape, _builder)
@builtin @builtin
def reshape(input, shape, builder=None): def reshape(input, shape, _builder=None):
""" """
Tries to reshape the given block to a new shape. Tries to reshape the given block to a new shape.
@@ -361,7 +356,7 @@ def reshape(input, shape, builder=None):
:type shape: Tuple[int] :type shape: Tuple[int]
""" """
return frontend.reshape(input, shape, builder) return frontend.reshape(input, shape, _builder)
# ----------------------- # -----------------------
@@ -370,7 +365,7 @@ def reshape(input, shape, builder=None):
@builtin @builtin
def dot(input, other, builder=None): def dot(input, other, _builder=None):
""" """
Returns the matrix product of two blocks. Returns the matrix product of two blocks.
@@ -381,16 +376,16 @@ def dot(input, other, builder=None):
:param other: The second block to be multiplied. :param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`} :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
""" """
return frontend.dot(input, other, builder) return frontend.dot(input, other, _builder)
# ----------------------- # -----------------------
# Memory Operations # Non-Atomic Memory Operations
# ----------------------- # -----------------------
@builtin @builtin
def load(pointer, mask=None, other=None, builder=None): def load(pointer, mask=None, other=None, _builder=None):
""" """
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
@@ -405,11 +400,11 @@ def load(pointer, mask=None, other=None, builder=None):
:param other: if mask[idx] is false, return other[idx] :param other: if mask[idx] is false, return other[idx]
:type other: Block, optional :type other: Block, optional
""" """
return frontend.load(pointer, mask, other, builder) return frontend.load(pointer, mask, other, _builder)
@builtin @builtin
def store(pointer, value, mask=None, builder=None): def store(pointer, value, mask=None, _builder=None):
""" """
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -422,13 +417,20 @@ def store(pointer, value, mask=None, builder=None):
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
:type mask: Block of triton.int1, optional :type mask: Block of triton.int1, optional
""" """
return frontend.store(pointer, value, mask, builder) return frontend.store(pointer, value, mask, _builder)
@builtin # -----------------------
def atomic_cas(pointer, cmp, val, builder=None): # Atomic Memory Operations
""" # -----------------------
Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`.
def _add_atomic_docstr(name):
def _decorator(func):
docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap. :param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType :type pointer: Block of dtype=triton.PointerDType
@@ -437,62 +439,56 @@ def atomic_cas(pointer, cmp, val, builder=None):
:param val: The values to copy in case the expected value matches the contained value. :param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty` :type val: Block of dtype=`pointer.dtype.element_ty`
""" """
func.__doc__ = docstr.format(name=name)
return func
return frontend.atomic_cas(pointer, cmp, val, builder) return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
return frontend.atomic_cas(pointer, cmp, val, _builder)
@builtin @builtin
def atomic_xchg(pointer, val, mask=None, builder=None): @_add_atomic_docstr("exchange")
""" def atomic_xchg(pointer, val, mask=None, _builder=None):
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values. return frontend.atomic_xchg(pointer, val, mask, _builder)
:param pointer: The memory locations which contain the old values @builtin
:type pointer: Block of dtype=triton.PointerDType @_add_atomic_docstr("add")
:param val: The new values to store def atomic_add(pointer, val, mask=None, _builder=None):
:type val: Block of dtype=`pointer.dtype.element_ty` return frontend.atomic_add(pointer, val, mask, _builder)
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
:type mask: Block of triton.int1, optional
"""
return frontend.atomic_xchg(pointer, val, mask, builder)
@builtin @builtin
def atomic_add(pointer, val, mask=None, builder=None): @_add_atomic_docstr("max")
""" def atomic_max(pointer, val, mask=None, _builder=None):
Performs an atomic add and the memory locations specified by :code:`pointer`. return frontend.atomic_max(pointer, val, mask, _builder)
:param pointer: The memory locations which contain the old values
:type pointer: Block of dtype=triton.PointerDType
:param val: The values to add
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
:type mask: Block of triton.int1, optional
"""
return frontend.atomic_add(pointer, val, mask, builder)
@builtin @builtin
def atomic_max(pointer, val, mask=None, builder=None): @_add_atomic_docstr("min")
return frontend.atomic_max(pointer, val, mask, builder) def atomic_min(pointer, val, mask=None, _builder=None):
return frontend.atomic_min(pointer, val, mask, _builder)
@builtin @builtin
def atomic_min(pointer, val, mask=None, builder=None): @_add_atomic_docstr("logical and")
return frontend.atomic_min(pointer, val, mask, builder) def atomic_and(pointer, val, mask=None, _builder=None):
return frontend.atomic_and(pointer, val, mask, _builder)
@builtin @builtin
def atomic_and(pointer, val, mask=None, builder=None): @_add_atomic_docstr("logical or")
return frontend.atomic_and(pointer, val, mask, builder) def atomic_or(pointer, val, mask=None, _builder=None):
return frontend.atomic_or(pointer, val, mask, _builder)
@builtin @builtin
def atomic_or(pointer, val, mask=None, builder=None): @_add_atomic_docstr("logical xor")
return frontend.atomic_or(pointer, val, mask, builder) def atomic_xor(pointer, val, mask=None, _builder=None):
return frontend.atomic_xor(pointer, val, mask, _builder)
@builtin
def atomic_xor(pointer, val, mask=None, builder=None):
return frontend.atomic_xor(pointer, val, mask, builder)
# ----------------------- # -----------------------
@@ -501,7 +497,7 @@ def atomic_xor(pointer, val, mask=None, builder=None):
@builtin @builtin
def where(condition, x, y, builder=None): def where(condition, x, y, _builder=None):
""" """
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
@@ -517,108 +513,89 @@ def where(condition, x, y, builder=None):
:param x: values selected at indices where condition is True. :param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False. :param y: values selected at indices where condition is False.
""" """
return frontend.where(condition, x, y, builder) return frontend.where(condition, x, y, _builder)
# ----------------------- # -----------------------
# Math # Math
# ----------------------- # -----------------------
def _add_math_1arg_docstr(name):
@builtin def _decorator(func):
def exp(x, builder=None): docstr = """
""" Computes the element-wise {name} of :code:`x`
Computes the element-wise exponential of :code:`x`
:param x: the input values :param x: the input values
:type x: Block :type x: Block
""" """
func.__doc__ = docstr.format(name=name)
return func
return frontend.exp(x, builder) return _decorator
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
return frontend.exp(x, _builder)
@builtin @builtin
def log(x, builder=None): @_add_math_1arg_docstr("natural logarithm")
""" def log(x, _builder=None):
Computes the element-wise natural logarithm of :code:`x` return frontend.log(x, _builder)
:param x: the input values
:type x: Block
"""
return frontend.log(x, builder)
@builtin @builtin
def cos(x, builder=None): @_add_math_1arg_docstr("cosine")
""" def cos(x, _builder=None):
Computes the element-wise cosine of :code:`x` return frontend.cos(x, _builder)
:param x: the input values
:type x: Block
"""
return frontend.cos(x, builder)
@builtin @builtin
def sin(x, builder=None): @_add_math_1arg_docstr("sine")
""" def sin(x, _builder=None):
Computes the element-wise sine of :code:`x` return frontend.sin(x, _builder)
:param x: the input values
:type x: Block
"""
return frontend.sin(x, builder)
@builtin @builtin
def sqrt(x, builder=None): @_add_math_1arg_docstr("square root")
""" def sqrt(x, _builder=None):
Computes the element-wise square root of :code:`x` return frontend.sqrt(x, _builder)
:param x: the input values
:type x: Block
"""
return frontend.sqrt(x, builder)
# ----------------------- # -----------------------
# Reductions # Reductions
# ----------------------- # -----------------------
def _add_reduction_docstr(name):
@builtin def _decorator(func):
def max(input, axis, builder=None): docstr = """
""" Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis`
Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values :param input: the input values
:param axis: the dimension along which the reduction should be done :param axis: the dimension along which the reduction should be done
""" """
return frontend.max(input, axis, builder) func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
return frontend.max(input, axis, _builder)
@builtin @builtin
def min(input, axis, builder=None): @_add_reduction_docstr("minimum")
""" def min(input, axis, _builder=None):
Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis` return frontend.min(input, axis, _builder)
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.min(input, axis, builder)
@builtin @builtin
def sum(input, axis, builder=None): @_add_reduction_docstr("sum")
""" def sum(input, axis, _builder=None):
Returns the sum of all elements in the :code:`input` block along the provided :code:`axis` return frontend.sum(input, axis, _builder)
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.sum(input, axis, builder)
# ----------------------- # -----------------------
@@ -627,24 +604,24 @@ def sum(input, axis, builder=None):
@builtin @builtin
def debug_barrier(builder=None): def debug_barrier(_builder=None):
return frontend.debug_barrier(builder) return frontend.debug_barrier(_builder)
@builtin @builtin
def multiple_of(input, value, builder=None): def multiple_of(input, value, _builder=None):
""" """
Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
""" """
return frontend.multiple_of(input, value, builder) return frontend.multiple_of(input, value, _builder)
@builtin @builtin
def max_contiguous(input, value, builder=None): def max_contiguous(input, value, _builder=None):
""" """
Let the compiler know that the `value` first values in :code:`input` are contiguous. Let the compiler know that the `value` first values in :code:`input` are contiguous.
""" """
return frontend.max_contiguous(input, value, builder) return frontend.max_contiguous(input, value, _builder)
# ----------------------- # -----------------------
@@ -690,24 +667,14 @@ def maximum(x, y):
@triton.jit @triton.jit
@_add_math_1arg_docstr("sigmoid")
def sigmoid(x): def sigmoid(x):
"""
Computes the element-wise sigmoid of :code:`x`.
:param x: the input block
:type x: Block
"""
return 1 / (1 + triton.language.exp(-x)) return 1 / (1 + triton.language.exp(-x))
@triton.jit @triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x): def softmax(x):
"""
Computes the element-wise softmax of :code:`x`.
:param x: the input block
:type x: Block
"""
z = x - triton.language.max(x, 0) z = x - triton.language.max(x, 0)
num = triton.language.exp(z) num = triton.language.exp(z)
den = triton.language.sum(num, 0) den = triton.language.sum(num, 0)

View File

@@ -322,7 +322,7 @@ else:
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[ x_vals=[
128 * i for i in range(1, 33) 128 * i for i in range(2, 33)
], # different possible values for `x_name` ], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`` # possible values for `line_arg``