[DOCS] Various improvements (#224)

- Added docstr for autotune, Config, heuristics
- Added docstr for atomics
- Hiding internal _builder argument used for built-in language primitives
- Re-factor docstr to use common templates between similar functions.
This commit is contained in:
Philippe Tillet
2021-08-18 11:15:53 -07:00
committed by GitHub
parent 226fde6ea1
commit f26a48a3b4
6 changed files with 275 additions and 222 deletions

View File

@@ -193,11 +193,11 @@ class CodeGenerator(ast.NodeVisitor):
kws = dict()
if self.is_triton_object(lhs):
kws['builder'] = self.builder
kws['_builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented:
if self.is_triton_object(rhs):
kws['builder'] = self.builder
kws['_builder'] = self.builder
fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws)
return ret
@@ -260,10 +260,10 @@ class CodeGenerator(ast.NodeVisitor):
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, builder=self.builder)
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
@@ -275,7 +275,7 @@ class CodeGenerator(ast.NodeVisitor):
ast.Invert: '__invert__',
}[type(node.op)]
if self.is_triton_object(op):
return getattr(op, fn)(builder=self.builder)
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
def visit_While(self, node):
@@ -308,7 +308,7 @@ class CodeGenerator(ast.NodeVisitor):
lhs = self.visit(node.value)
slices = self.visit(node.slice)
if self.is_triton_object(lhs):
return lhs.__getitem__(slices, builder=self.builder)
return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices]
def visit_ExtSlice(self, node):
@@ -331,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
_builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation
@@ -385,7 +385,7 @@ class CodeGenerator(ast.NodeVisitor):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language:
return fn(*args, builder=self.builder, **kws)
return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws)
def visit_Num(self, node):
@@ -714,6 +714,19 @@ class JITFunction:
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
"""
def __init__(self, meta, num_warps=4, num_stages=2):
self.meta = meta
self.num_warps = num_warps
@@ -721,6 +734,35 @@ class Config:
def autotune(configs, key, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
@@ -732,6 +774,23 @@ def autotune(configs, key, reset_to_zero=None):
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
.type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
def wrapper(kernel):
def fun(*args, **meta):
@@ -767,6 +826,8 @@ def jit(fn):
return JITFunction(fn)
######
def cdiv(x, y):
return (x + y - 1) // y

View File

@@ -49,16 +49,11 @@ for name in dir(frontend):
def builtin(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if 'builder' not in kwargs or \
kwargs['builder'] is None:
raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?")
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
if wrapper.__doc__:
wrapper.__doc__ += """\
:param builder: IR builder to generate code into
:type builder: triton.ir.builder, optional from within JIT'ed functions
"""
return wrapper
@@ -124,120 +119,120 @@ class block:
self.dtype = block._init_dtype(self.handle.type.scalar)
@builtin
def __add__(self, other, builder=None):
return frontend.add(self, other, builder)
def __add__(self, other, _builder=None):
return frontend.add(self, other, _builder)
def __radd__(self, other, builder=None):
return self.__add__(other, builder=builder)
def __radd__(self, other, _builder=None):
return self.__add__(other, _builder=_builder)
@builtin
def __sub__(self, other, builder=None):
return frontend.sub(self, other, builder)
def __sub__(self, other, _builder=None):
return frontend.sub(self, other, _builder)
def __rsub__(self, other, builder=None):
return frontend.sub(other, self, builder)
def __rsub__(self, other, _builder=None):
return frontend.sub(other, self, _builder)
@builtin
def __mul__(self, other, builder=None):
return frontend.mul(self, other, builder)
def __mul__(self, other, _builder=None):
return frontend.mul(self, other, _builder)
def __rmul__(self, other, builder=None):
return self.__mul__(other, builder=builder)
def __rmul__(self, other, _builder=None):
return self.__mul__(other, _builder=_builder)
@builtin
def __truediv__(self, other, builder=None):
return frontend.truediv(self, other, builder)
def __truediv__(self, other, _builder=None):
return frontend.truediv(self, other, _builder)
def __rtruediv__(self, other, builder=None):
return frontend.truediv(other, self, builder)
def __rtruediv__(self, other, _builder=None):
return frontend.truediv(other, self, _builder)
@builtin
def __floordiv__(self, other, builder=None):
return frontend.floordiv(self, other, builder)
def __floordiv__(self, other, _builder=None):
return frontend.floordiv(self, other, _builder)
@builtin
def __mod__(self, other, builder=None):
return frontend.mod(self, other, builder)
def __mod__(self, other, _builder=None):
return frontend.mod(self, other, _builder)
# unary operators
@builtin
def __neg__(self, builder=None):
return frontend.minus(self, builder)
def __neg__(self, _builder=None):
return frontend.minus(self, _builder)
@builtin
def __invert__(self, builder=None):
return frontend.invert(self, builder)
def __invert__(self, _builder=None):
return frontend.invert(self, _builder)
# bitwise operators
@builtin
def __and__(self, other, builder=None):
return frontend.and_(self, other, builder)
def __and__(self, other, _builder=None):
return frontend.and_(self, other, _builder)
@builtin
def __or__(self, other, builder=None):
return frontend.or_(self, other, builder)
def __or__(self, other, _builder=None):
return frontend.or_(self, other, _builder)
@builtin
def __xor__(self, other, builder=None):
return frontend.xor_(self, other, builder)
def __xor__(self, other, _builder=None):
return frontend.xor_(self, other, _builder)
@builtin
def __lshift__(self, other, builder=None):
return frontend.shl(self, other, builder)
def __lshift__(self, other, _builder=None):
return frontend.shl(self, other, _builder)
@builtin
def __rshift__(self, other, builder=None):
return frontend.lshr(self, other, builder)
def __rshift__(self, other, _builder=None):
return frontend.lshr(self, other, _builder)
# comparison operators
# >
@builtin
def __gt__(self, other, builder=None):
return frontend.greater_than(self, other, builder)
def __gt__(self, other, _builder=None):
return frontend.greater_than(self, other, _builder)
@builtin
def __rgt__(self, other, builder=None):
return frontend.greater_than(other, self, builder)
def __rgt__(self, other, _builder=None):
return frontend.greater_than(other, self, _builder)
# >=
@builtin
def __ge__(self, other, builder=None):
return frontend.greater_equal(self, other, builder)
def __ge__(self, other, _builder=None):
return frontend.greater_equal(self, other, _builder)
def __rge__(self, other, builder=None):
return frontend.greater_equal(other, self, builder)
def __rge__(self, other, _builder=None):
return frontend.greater_equal(other, self, _builder)
# <
@builtin
def __lt__(self, other, builder=None):
return frontend.less_than(self, other, builder)
def __lt__(self, other, _builder=None):
return frontend.less_than(self, other, _builder)
@builtin
def __rlt__(self, other, builder=None):
return frontend.less_than(other, self, builder)
def __rlt__(self, other, _builder=None):
return frontend.less_than(other, self, _builder)
# <=
@builtin
def __le__(self, other, builder=None):
return frontend.less_equal(self, other, builder)
def __le__(self, other, _builder=None):
return frontend.less_equal(self, other, _builder)
@builtin
def __rle__(self, other, builder=None):
return frontend.less_equal(other, self, builder)
def __rle__(self, other, _builder=None):
return frontend.less_equal(other, self, _builder)
# ==
@builtin
def __eq__(self, other, builder=None):
return frontend.equal(self, other, builder)
def __eq__(self, other, _builder=None):
return frontend.equal(self, other, _builder)
@builtin
def __ne__(self, other, builder=None):
return frontend.not_equal(self, other, builder)
def __ne__(self, other, _builder=None):
return frontend.not_equal(self, other, _builder)
@builtin
def __getitem__(self, slices, builder=None):
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
slices = [slices]
src_shape = self.shape
@@ -249,15 +244,15 @@ class block:
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr])
curr += 1
ret = frontend.reshape(self, dst_shape, builder)
ret = frontend.reshape(self, dst_shape, _builder)
return ret
@builtin
def to(self, dtype, bitcast=False, builder=None):
dtype = dtype.handle(builder)
def to(self, dtype, bitcast=False, _builder=None):
dtype = dtype.handle(_builder)
if bitcast:
return frontend.bitcast(self, dtype, builder)
return frontend.cast(self, dtype, builder)
return frontend.bitcast(self, dtype, _builder)
return frontend.cast(self, dtype, _builder)
# -----------------------
@@ -266,25 +261,25 @@ class block:
@builtin
def program_id(axis, builder=None):
def program_id(axis, _builder=None):
"""
Returns the id of the current program instance along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.program_id(axis, builder)
return frontend.program_id(axis, _builder)
@builtin
def num_programs(axis, builder=None):
def num_programs(axis, _builder=None):
"""
Returns the number of program instances launched along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.num_programs(axis, builder)
return frontend.num_programs(axis, _builder)
# -----------------------
@@ -293,7 +288,7 @@ def num_programs(axis, builder=None):
@builtin
def arange(start, end, builder=None):
def arange(start, end, _builder=None):
"""
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
@@ -302,11 +297,11 @@ def arange(start, end, builder=None):
:param stop: End of the interval. Must be a power of two >= start.
:type stop: int
"""
return frontend.arange(start, end, builder)
return frontend.arange(start, end, _builder)
@builtin
def zeros(shape, dtype, builder=None):
def zeros(shape, dtype, _builder=None):
"""
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
@@ -316,7 +311,7 @@ def zeros(shape, dtype, builder=None):
:type dtype: DType
"""
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
return frontend.zeros(shape, dtype, builder)
return frontend.zeros(shape, dtype, _builder)
# -----------------------
@@ -325,7 +320,7 @@ def zeros(shape, dtype, builder=None):
@builtin
def broadcast(input, other, builder=None):
def broadcast(input, other, _builder=None):
"""
Tries to broadcast the two given blocks to a common compatible shape.
@@ -334,11 +329,11 @@ def broadcast(input, other, builder=None):
:param other: The second input block.
:type other: Block
"""
return frontend.broadcast(input, other, builder)
return frontend.broadcast(input, other, _builder)
@builtin
def broadcast_to(input, shape, builder=None):
def broadcast_to(input, shape, _builder=None):
"""
Tries to broadcast the given block to a new :code:`shape`.
@@ -347,11 +342,11 @@ def broadcast_to(input, shape, builder=None):
:param shape: The desired shape.
:type shape: Tuple[int]
"""
return frontend.broadcast_to(input, shape, builder)
return frontend.broadcast_to(input, shape, _builder)
@builtin
def reshape(input, shape, builder=None):
def reshape(input, shape, _builder=None):
"""
Tries to reshape the given block to a new shape.
@@ -361,7 +356,7 @@ def reshape(input, shape, builder=None):
:type shape: Tuple[int]
"""
return frontend.reshape(input, shape, builder)
return frontend.reshape(input, shape, _builder)
# -----------------------
@@ -370,7 +365,7 @@ def reshape(input, shape, builder=None):
@builtin
def dot(input, other, builder=None):
def dot(input, other, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -381,16 +376,16 @@ def dot(input, other, builder=None):
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
"""
return frontend.dot(input, other, builder)
return frontend.dot(input, other, _builder)
# -----------------------
# Memory Operations
# Non-Atomic Memory Operations
# -----------------------
@builtin
def load(pointer, mask=None, other=None, builder=None):
def load(pointer, mask=None, other=None, _builder=None):
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
@@ -405,11 +400,11 @@ def load(pointer, mask=None, other=None, builder=None):
:param other: if mask[idx] is false, return other[idx]
:type other: Block, optional
"""
return frontend.load(pointer, mask, other, builder)
return frontend.load(pointer, mask, other, _builder)
@builtin
def store(pointer, value, mask=None, builder=None):
def store(pointer, value, mask=None, _builder=None):
"""
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -422,13 +417,20 @@ def store(pointer, value, mask=None, builder=None):
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
:type mask: Block of triton.int1, optional
"""
return frontend.store(pointer, value, mask, builder)
return frontend.store(pointer, value, mask, _builder)
@builtin
def atomic_cas(pointer, cmp, val, builder=None):
"""
Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`.
# -----------------------
# Atomic Memory Operations
# -----------------------
def _add_atomic_docstr(name):
def _decorator(func):
docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
@@ -437,62 +439,56 @@ def atomic_cas(pointer, cmp, val, builder=None):
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
"""
return frontend.atomic_cas(pointer, cmp, val, builder)
func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
return frontend.atomic_cas(pointer, cmp, val, _builder)
@builtin
def atomic_xchg(pointer, val, mask=None, builder=None):
"""
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values.
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder)
:param pointer: The memory locations which contain the old values
:type pointer: Block of dtype=triton.PointerDType
:param val: The new values to store
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
:type mask: Block of triton.int1, optional
"""
return frontend.atomic_xchg(pointer, val, mask, builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None):
return frontend.atomic_add(pointer, val, mask, _builder)
@builtin
def atomic_add(pointer, val, mask=None, builder=None):
"""
Performs an atomic add and the memory locations specified by :code:`pointer`.
:param pointer: The memory locations which contain the old values
:type pointer: Block of dtype=triton.PointerDType
:param val: The values to add
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
:type mask: Block of triton.int1, optional
"""
return frontend.atomic_add(pointer, val, mask, builder)
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, _builder=None):
return frontend.atomic_max(pointer, val, mask, _builder)
@builtin
def atomic_max(pointer, val, mask=None, builder=None):
return frontend.atomic_max(pointer, val, mask, builder)
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, _builder=None):
return frontend.atomic_min(pointer, val, mask, _builder)
@builtin
def atomic_min(pointer, val, mask=None, builder=None):
return frontend.atomic_min(pointer, val, mask, builder)
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, _builder=None):
return frontend.atomic_and(pointer, val, mask, _builder)
@builtin
def atomic_and(pointer, val, mask=None, builder=None):
return frontend.atomic_and(pointer, val, mask, builder)
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, _builder=None):
return frontend.atomic_or(pointer, val, mask, _builder)
@builtin
def atomic_or(pointer, val, mask=None, builder=None):
return frontend.atomic_or(pointer, val, mask, builder)
@builtin
def atomic_xor(pointer, val, mask=None, builder=None):
return frontend.atomic_xor(pointer, val, mask, builder)
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, _builder=None):
return frontend.atomic_xor(pointer, val, mask, _builder)
# -----------------------
@@ -501,7 +497,7 @@ def atomic_xor(pointer, val, mask=None, builder=None):
@builtin
def where(condition, x, y, builder=None):
def where(condition, x, y, _builder=None):
"""
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
@@ -517,108 +513,89 @@ def where(condition, x, y, builder=None):
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
"""
return frontend.where(condition, x, y, builder)
return frontend.where(condition, x, y, _builder)
# -----------------------
# Math
# -----------------------
def _add_math_1arg_docstr(name):
@builtin
def exp(x, builder=None):
"""
Computes the element-wise exponential of :code:`x`
def _decorator(func):
docstr = """
Computes the element-wise {name} of :code:`x`
:param x: the input values
:type x: Block
"""
func.__doc__ = docstr.format(name=name)
return func
return _decorator
return frontend.exp(x, builder)
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
return frontend.exp(x, _builder)
@builtin
def log(x, builder=None):
"""
Computes the element-wise natural logarithm of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.log(x, builder)
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
return frontend.log(x, _builder)
@builtin
def cos(x, builder=None):
"""
Computes the element-wise cosine of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.cos(x, builder)
@_add_math_1arg_docstr("cosine")
def cos(x, _builder=None):
return frontend.cos(x, _builder)
@builtin
def sin(x, builder=None):
"""
Computes the element-wise sine of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.sin(x, builder)
@_add_math_1arg_docstr("sine")
def sin(x, _builder=None):
return frontend.sin(x, _builder)
@builtin
def sqrt(x, builder=None):
"""
Computes the element-wise square root of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.sqrt(x, builder)
@_add_math_1arg_docstr("square root")
def sqrt(x, _builder=None):
return frontend.sqrt(x, _builder)
# -----------------------
# Reductions
# -----------------------
def _add_reduction_docstr(name):
@builtin
def max(input, axis, builder=None):
"""
Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`
def _decorator(func):
docstr = """
Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.max(input, axis, builder)
func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
return frontend.max(input, axis, _builder)
@builtin
def min(input, axis, builder=None):
"""
Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.min(input, axis, builder)
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
return frontend.min(input, axis, _builder)
@builtin
def sum(input, axis, builder=None):
"""
Returns the sum of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.sum(input, axis, builder)
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder)
# -----------------------
@@ -627,24 +604,24 @@ def sum(input, axis, builder=None):
@builtin
def debug_barrier(builder=None):
return frontend.debug_barrier(builder)
def debug_barrier(_builder=None):
return frontend.debug_barrier(_builder)
@builtin
def multiple_of(input, value, builder=None):
def multiple_of(input, value, _builder=None):
"""
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
"""
return frontend.multiple_of(input, value, builder)
return frontend.multiple_of(input, value, _builder)
@builtin
def max_contiguous(input, value, builder=None):
def max_contiguous(input, value, _builder=None):
"""
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
"""
return frontend.max_contiguous(input, value, builder)
return frontend.max_contiguous(input, value, _builder)
# -----------------------
@@ -690,24 +667,14 @@ def maximum(x, y):
@triton.jit
@_add_math_1arg_docstr("sigmoid")
def sigmoid(x):
"""
Computes the element-wise sigmoid of :code:`x`.
:param x: the input block
:type x: Block
"""
return 1 / (1 + triton.language.exp(-x))
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x):
"""
Computes the element-wise softmax of :code:`x`.
:param x: the input block
:type x: Block
"""
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)

View File

@@ -322,7 +322,7 @@ else:
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(1, 33)
128 * i for i in range(2, 33)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``