From f26a48a3b408b219dbf3e187ef14b7f71fd59de5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 18 Aug 2021 11:15:53 -0700 Subject: [PATCH] [DOCS] Various improvements (#224) - Added docstr for autotune, Config, heuristics - Added docstr for atomics - Hiding internal _builder argument used for built-in language primitives - Re-factor docstr to use common templates between similar functions. --- docs/conf.py | 10 + docs/python-api/triton.language.rst | 12 + docs/python-api/triton.rst | 5 +- python/triton/code_gen.py | 77 +++- python/triton/language.py | 391 +++++++++---------- python/tutorials/03-matrix-multiplication.py | 2 +- 6 files changed, 275 insertions(+), 222 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7a8d8c3f5..1107ef171 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,11 +24,20 @@ # -- General configuration ------------------------------------------------ + + +def process_sig(app, what, name, obj, options, signature, return_annotation): + if signature and '_builder' in signature: + signature = signature.split('_builder')[0] + ")" + return (signature, return_annotation) + def setup(app): """Customize function args retrieving to get args under decorator.""" import sphinx import triton + app.connect("autodoc-process-signature", process_sig) + def forward_jit_fn(func): old = func @@ -39,6 +48,7 @@ def setup(app): return wrapped + old_documenter = sphinx.ext.autosummary.get_documenter def documenter(app, obj, parent): diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index b86ee798f..4cb437faf 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -98,6 +98,18 @@ Reduction Ops min sum +Atomic Ops +--------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + atomic_cas + atomic_add + atomic_max + atomic_min + Comparison ops --------------- diff --git a/docs/python-api/triton.rst b/docs/python-api/triton.rst index 2db99da77..70bda4931 100644 --- a/docs/python-api/triton.rst +++ b/docs/python-api/triton.rst @@ -7,4 +7,7 @@ triton :toctree: generated :nosignatures: - jit \ No newline at end of file + jit + autotune + heuristics + Config \ No newline at end of file diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 7b9d5b037..0172d5220 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -193,11 +193,11 @@ class CodeGenerator(ast.NodeVisitor): kws = dict() if self.is_triton_object(lhs): - kws['builder'] = self.builder + kws['_builder'] = self.builder ret = getattr(lhs, fn)(rhs, **kws) if ret is NotImplemented: if self.is_triton_object(rhs): - kws['builder'] = self.builder + kws['_builder'] = self.builder fn = fn[:2] + 'r' + fn[2:] ret = getattr(rhs, fn)(lhs, **kws) return ret @@ -260,10 +260,10 @@ class CodeGenerator(ast.NodeVisitor): ast.IsNot: '__ne__', }[type(node.ops[0])] if self.is_triton_object(lhs): - return getattr(lhs, fn)(rhs, builder=self.builder) + return getattr(lhs, fn)(rhs, _builder=self.builder) elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] - return getattr(rhs, fn)(lhs, builder=self.builder) + return getattr(rhs, fn)(lhs, _builder=self.builder) else: return getattr(lhs, fn)(rhs) @@ -275,7 +275,7 @@ class CodeGenerator(ast.NodeVisitor): ast.Invert: '__invert__', }[type(node.op)] if self.is_triton_object(op): - return getattr(op, fn)(builder=self.builder) + return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): @@ -308,7 +308,7 @@ class CodeGenerator(ast.NodeVisitor): lhs = self.visit(node.value) slices = self.visit(node.slice) if self.is_triton_object(lhs): - return lhs.__getitem__(slices, builder=self.builder) + return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] def visit_ExtSlice(self, node): @@ -331,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor): build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ self.visit(pos_cond_node),\ self.visit(neg_cond_node),\ - builder=self.builder) + _builder=self.builder) #cond_node = neg_cond_node step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation @@ -385,7 +385,7 @@ class CodeGenerator(ast.NodeVisitor): return fn(*args, generator=self, **kws) if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language: - return fn(*args, builder=self.builder, **kws) + return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) def visit_Num(self, node): @@ -714,6 +714,19 @@ class JITFunction: class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_stages: int + """ def __init__(self, meta, num_warps=4, num_stages=2): self.meta = meta self.num_warps = num_warps @@ -721,6 +734,35 @@ class Config: def autotune(configs, key, reset_to_zero=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ def decorator(fn): def wrapper(kernel): return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero) @@ -732,6 +774,23 @@ def autotune(configs, key, reset_to_zero=None): def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + + + .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + .type values: dict[str, Callable[[list[Any]], Any]] + """ def decorator(fn): def wrapper(kernel): def fun(*args, **meta): @@ -767,6 +826,8 @@ def jit(fn): return JITFunction(fn) +###### + def cdiv(x, y): return (x + y - 1) // y diff --git a/python/triton/language.py b/python/triton/language.py index 7255bf542..6cf84f733 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -49,16 +49,11 @@ for name in dir(frontend): def builtin(fn): @wraps(fn) def wrapper(*args, **kwargs): - if 'builder' not in kwargs or \ - kwargs['builder'] is None: - raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?") + if '_builder' not in kwargs or \ + kwargs['_builder'] is None: + raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) - if wrapper.__doc__: - wrapper.__doc__ += """\ -:param builder: IR builder to generate code into - :type builder: triton.ir.builder, optional from within JIT'ed functions -""" return wrapper @@ -124,120 +119,120 @@ class block: self.dtype = block._init_dtype(self.handle.type.scalar) @builtin - def __add__(self, other, builder=None): - return frontend.add(self, other, builder) + def __add__(self, other, _builder=None): + return frontend.add(self, other, _builder) - def __radd__(self, other, builder=None): - return self.__add__(other, builder=builder) + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) @builtin - def __sub__(self, other, builder=None): - return frontend.sub(self, other, builder) + def __sub__(self, other, _builder=None): + return frontend.sub(self, other, _builder) - def __rsub__(self, other, builder=None): - return frontend.sub(other, self, builder) + def __rsub__(self, other, _builder=None): + return frontend.sub(other, self, _builder) @builtin - def __mul__(self, other, builder=None): - return frontend.mul(self, other, builder) + def __mul__(self, other, _builder=None): + return frontend.mul(self, other, _builder) - def __rmul__(self, other, builder=None): - return self.__mul__(other, builder=builder) + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) @builtin - def __truediv__(self, other, builder=None): - return frontend.truediv(self, other, builder) + def __truediv__(self, other, _builder=None): + return frontend.truediv(self, other, _builder) - def __rtruediv__(self, other, builder=None): - return frontend.truediv(other, self, builder) + def __rtruediv__(self, other, _builder=None): + return frontend.truediv(other, self, _builder) @builtin - def __floordiv__(self, other, builder=None): - return frontend.floordiv(self, other, builder) + def __floordiv__(self, other, _builder=None): + return frontend.floordiv(self, other, _builder) @builtin - def __mod__(self, other, builder=None): - return frontend.mod(self, other, builder) + def __mod__(self, other, _builder=None): + return frontend.mod(self, other, _builder) # unary operators @builtin - def __neg__(self, builder=None): - return frontend.minus(self, builder) + def __neg__(self, _builder=None): + return frontend.minus(self, _builder) @builtin - def __invert__(self, builder=None): - return frontend.invert(self, builder) + def __invert__(self, _builder=None): + return frontend.invert(self, _builder) # bitwise operators @builtin - def __and__(self, other, builder=None): - return frontend.and_(self, other, builder) + def __and__(self, other, _builder=None): + return frontend.and_(self, other, _builder) @builtin - def __or__(self, other, builder=None): - return frontend.or_(self, other, builder) + def __or__(self, other, _builder=None): + return frontend.or_(self, other, _builder) @builtin - def __xor__(self, other, builder=None): - return frontend.xor_(self, other, builder) + def __xor__(self, other, _builder=None): + return frontend.xor_(self, other, _builder) @builtin - def __lshift__(self, other, builder=None): - return frontend.shl(self, other, builder) + def __lshift__(self, other, _builder=None): + return frontend.shl(self, other, _builder) @builtin - def __rshift__(self, other, builder=None): - return frontend.lshr(self, other, builder) + def __rshift__(self, other, _builder=None): + return frontend.lshr(self, other, _builder) # comparison operators # > @builtin - def __gt__(self, other, builder=None): - return frontend.greater_than(self, other, builder) + def __gt__(self, other, _builder=None): + return frontend.greater_than(self, other, _builder) @builtin - def __rgt__(self, other, builder=None): - return frontend.greater_than(other, self, builder) + def __rgt__(self, other, _builder=None): + return frontend.greater_than(other, self, _builder) # >= @builtin - def __ge__(self, other, builder=None): - return frontend.greater_equal(self, other, builder) + def __ge__(self, other, _builder=None): + return frontend.greater_equal(self, other, _builder) - def __rge__(self, other, builder=None): - return frontend.greater_equal(other, self, builder) + def __rge__(self, other, _builder=None): + return frontend.greater_equal(other, self, _builder) # < @builtin - def __lt__(self, other, builder=None): - return frontend.less_than(self, other, builder) + def __lt__(self, other, _builder=None): + return frontend.less_than(self, other, _builder) @builtin - def __rlt__(self, other, builder=None): - return frontend.less_than(other, self, builder) + def __rlt__(self, other, _builder=None): + return frontend.less_than(other, self, _builder) # <= @builtin - def __le__(self, other, builder=None): - return frontend.less_equal(self, other, builder) + def __le__(self, other, _builder=None): + return frontend.less_equal(self, other, _builder) @builtin - def __rle__(self, other, builder=None): - return frontend.less_equal(other, self, builder) + def __rle__(self, other, _builder=None): + return frontend.less_equal(other, self, _builder) # == @builtin - def __eq__(self, other, builder=None): - return frontend.equal(self, other, builder) + def __eq__(self, other, _builder=None): + return frontend.equal(self, other, _builder) @builtin - def __ne__(self, other, builder=None): - return frontend.not_equal(self, other, builder) + def __ne__(self, other, _builder=None): + return frontend.not_equal(self, other, _builder) @builtin - def __getitem__(self, slices, builder=None): + def __getitem__(self, slices, _builder=None): if isinstance(slices, slice): slices = [slices] src_shape = self.shape @@ -249,15 +244,15 @@ class block: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr]) curr += 1 - ret = frontend.reshape(self, dst_shape, builder) + ret = frontend.reshape(self, dst_shape, _builder) return ret @builtin - def to(self, dtype, bitcast=False, builder=None): - dtype = dtype.handle(builder) + def to(self, dtype, bitcast=False, _builder=None): + dtype = dtype.handle(_builder) if bitcast: - return frontend.bitcast(self, dtype, builder) - return frontend.cast(self, dtype, builder) + return frontend.bitcast(self, dtype, _builder) + return frontend.cast(self, dtype, _builder) # ----------------------- @@ -266,25 +261,25 @@ class block: @builtin -def program_id(axis, builder=None): +def program_id(axis, _builder=None): """ Returns the id of the current program instance along the given :code:`axis`. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - return frontend.program_id(axis, builder) + return frontend.program_id(axis, _builder) @builtin -def num_programs(axis, builder=None): +def num_programs(axis, _builder=None): """ Returns the number of program instances launched along the given :code:`axis`. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - return frontend.num_programs(axis, builder) + return frontend.num_programs(axis, _builder) # ----------------------- @@ -293,7 +288,7 @@ def num_programs(axis, builder=None): @builtin -def arange(start, end, builder=None): +def arange(start, end, _builder=None): """ Returns contiguous values within the open interval [:code:`start`, :code:`end`). @@ -302,11 +297,11 @@ def arange(start, end, builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - return frontend.arange(start, end, builder) + return frontend.arange(start, end, _builder) @builtin -def zeros(shape, dtype, builder=None): +def zeros(shape, dtype, _builder=None): """ Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. @@ -316,7 +311,7 @@ def zeros(shape, dtype, builder=None): :type dtype: DType """ shape = [int(x.handle) if isinstance(x, block) else x for x in shape] - return frontend.zeros(shape, dtype, builder) + return frontend.zeros(shape, dtype, _builder) # ----------------------- @@ -325,7 +320,7 @@ def zeros(shape, dtype, builder=None): @builtin -def broadcast(input, other, builder=None): +def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. @@ -334,11 +329,11 @@ def broadcast(input, other, builder=None): :param other: The second input block. :type other: Block """ - return frontend.broadcast(input, other, builder) + return frontend.broadcast(input, other, _builder) @builtin -def broadcast_to(input, shape, builder=None): +def broadcast_to(input, shape, _builder=None): """ Tries to broadcast the given block to a new :code:`shape`. @@ -347,11 +342,11 @@ def broadcast_to(input, shape, builder=None): :param shape: The desired shape. :type shape: Tuple[int] """ - return frontend.broadcast_to(input, shape, builder) + return frontend.broadcast_to(input, shape, _builder) @builtin -def reshape(input, shape, builder=None): +def reshape(input, shape, _builder=None): """ Tries to reshape the given block to a new shape. @@ -361,7 +356,7 @@ def reshape(input, shape, builder=None): :type shape: Tuple[int] """ - return frontend.reshape(input, shape, builder) + return frontend.reshape(input, shape, _builder) # ----------------------- @@ -370,7 +365,7 @@ def reshape(input, shape, builder=None): @builtin -def dot(input, other, builder=None): +def dot(input, other, _builder=None): """ Returns the matrix product of two blocks. @@ -381,16 +376,16 @@ def dot(input, other, builder=None): :param other: The second block to be multiplied. :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`} """ - return frontend.dot(input, other, builder) + return frontend.dot(input, other, _builder) # ----------------------- -# Memory Operations +# Non-Atomic Memory Operations # ----------------------- @builtin -def load(pointer, mask=None, other=None, builder=None): +def load(pointer, mask=None, other=None, _builder=None): """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. @@ -405,11 +400,11 @@ def load(pointer, mask=None, other=None, builder=None): :param other: if mask[idx] is false, return other[idx] :type other: Block, optional """ - return frontend.load(pointer, mask, other, builder) + return frontend.load(pointer, mask, other, _builder) @builtin -def store(pointer, value, mask=None, builder=None): +def store(pointer, value, mask=None, _builder=None): """ Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. @@ -422,13 +417,20 @@ def store(pointer, value, mask=None, builder=None): :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. :type mask: Block of triton.int1, optional """ - return frontend.store(pointer, value, mask, builder) + return frontend.store(pointer, value, mask, _builder) -@builtin -def atomic_cas(pointer, cmp, val, builder=None): - """ - Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`. +# ----------------------- +# Atomic Memory Operations +# ----------------------- + +def _add_atomic_docstr(name): + + def _decorator(func): + docstr = """ + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. :param pointer: The memory locations to compare-and-swap. :type pointer: Block of dtype=triton.PointerDType @@ -437,62 +439,56 @@ def atomic_cas(pointer, cmp, val, builder=None): :param val: The values to copy in case the expected value matches the contained value. :type val: Block of dtype=`pointer.dtype.element_ty` """ - - return frontend.atomic_cas(pointer, cmp, val, builder) + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + +@builtin +@_add_atomic_docstr("compare-and-swap") +def atomic_cas(pointer, cmp, val, _builder=None): + return frontend.atomic_cas(pointer, cmp, val, _builder) @builtin -def atomic_xchg(pointer, val, mask=None, builder=None): - """ - Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values. +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, _builder=None): + return frontend.atomic_xchg(pointer, val, mask, _builder) - :param pointer: The memory locations which contain the old values - :type pointer: Block of dtype=triton.PointerDType - :param val: The new values to store - :type val: Block of dtype=`pointer.dtype.element_ty` - :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected. - :type mask: Block of triton.int1, optional - """ - return frontend.atomic_xchg(pointer, val, mask, builder) +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, _builder=None): + return frontend.atomic_add(pointer, val, mask, _builder) @builtin -def atomic_add(pointer, val, mask=None, builder=None): - """ - Performs an atomic add and the memory locations specified by :code:`pointer`. - :param pointer: The memory locations which contain the old values - :type pointer: Block of dtype=triton.PointerDType - :param val: The values to add - :type val: Block of dtype=`pointer.dtype.element_ty` - :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected. - :type mask: Block of triton.int1, optional - """ - return frontend.atomic_add(pointer, val, mask, builder) +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, _builder=None): + return frontend.atomic_max(pointer, val, mask, _builder) @builtin -def atomic_max(pointer, val, mask=None, builder=None): - return frontend.atomic_max(pointer, val, mask, builder) +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, _builder=None): + return frontend.atomic_min(pointer, val, mask, _builder) @builtin -def atomic_min(pointer, val, mask=None, builder=None): - return frontend.atomic_min(pointer, val, mask, builder) +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, _builder=None): + return frontend.atomic_and(pointer, val, mask, _builder) @builtin -def atomic_and(pointer, val, mask=None, builder=None): - return frontend.atomic_and(pointer, val, mask, builder) +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, _builder=None): + return frontend.atomic_or(pointer, val, mask, _builder) @builtin -def atomic_or(pointer, val, mask=None, builder=None): - return frontend.atomic_or(pointer, val, mask, builder) - - -@builtin -def atomic_xor(pointer, val, mask=None, builder=None): - return frontend.atomic_xor(pointer, val, mask, builder) +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, _builder=None): + return frontend.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -501,7 +497,7 @@ def atomic_xor(pointer, val, mask=None, builder=None): @builtin -def where(condition, x, y, builder=None): +def where(condition, x, y, _builder=None): """ Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. @@ -517,108 +513,89 @@ def where(condition, x, y, builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - return frontend.where(condition, x, y, builder) + return frontend.where(condition, x, y, _builder) # ----------------------- # Math # ----------------------- +def _add_math_1arg_docstr(name): -@builtin -def exp(x, builder=None): - """ - Computes the element-wise exponential of :code:`x` + def _decorator(func): + docstr = """ + Computes the element-wise {name} of :code:`x` :param x: the input values :type x: Block """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator - return frontend.exp(x, builder) +@builtin +@_add_math_1arg_docstr("exponential") +def exp(x, _builder=None): + return frontend.exp(x, _builder) @builtin -def log(x, builder=None): - """ - Computes the element-wise natural logarithm of :code:`x` - - :param x: the input values - :type x: Block - """ - - return frontend.log(x, builder) +@_add_math_1arg_docstr("natural logarithm") +def log(x, _builder=None): + return frontend.log(x, _builder) @builtin -def cos(x, builder=None): - """ - Computes the element-wise cosine of :code:`x` - - :param x: the input values - :type x: Block - """ - - return frontend.cos(x, builder) +@_add_math_1arg_docstr("cosine") +def cos(x, _builder=None): + return frontend.cos(x, _builder) @builtin -def sin(x, builder=None): - """ - Computes the element-wise sine of :code:`x` - - :param x: the input values - :type x: Block - """ - - return frontend.sin(x, builder) +@_add_math_1arg_docstr("sine") +def sin(x, _builder=None): + return frontend.sin(x, _builder) @builtin -def sqrt(x, builder=None): - """ - Computes the element-wise square root of :code:`x` - - :param x: the input values - :type x: Block - """ - return frontend.sqrt(x, builder) +@_add_math_1arg_docstr("square root") +def sqrt(x, _builder=None): + return frontend.sqrt(x, _builder) # ----------------------- # Reductions # ----------------------- +def _add_reduction_docstr(name): -@builtin -def max(input, axis, builder=None): - """ - Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis` + def _decorator(func): + docstr = """ + Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done """ - return frontend.max(input, axis, builder) + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + +@builtin +@_add_reduction_docstr("maximum") +def max(input, axis, _builder=None): + return frontend.max(input, axis, _builder) @builtin -def min(input, axis, builder=None): - """ - Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis` - - :param input: the input values - :param axis: the dimension along which the reduction should be done - """ - return frontend.min(input, axis, builder) +@_add_reduction_docstr("minimum") +def min(input, axis, _builder=None): + return frontend.min(input, axis, _builder) @builtin -def sum(input, axis, builder=None): - """ - Returns the sum of all elements in the :code:`input` block along the provided :code:`axis` - - :param input: the input values - :param axis: the dimension along which the reduction should be done - """ - - return frontend.sum(input, axis, builder) +@_add_reduction_docstr("sum") +def sum(input, axis, _builder=None): + return frontend.sum(input, axis, _builder) # ----------------------- @@ -627,24 +604,24 @@ def sum(input, axis, builder=None): @builtin -def debug_barrier(builder=None): - return frontend.debug_barrier(builder) +def debug_barrier(_builder=None): + return frontend.debug_barrier(_builder) @builtin -def multiple_of(input, value, builder=None): +def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - return frontend.multiple_of(input, value, builder) + return frontend.multiple_of(input, value, _builder) @builtin -def max_contiguous(input, value, builder=None): +def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - return frontend.max_contiguous(input, value, builder) + return frontend.max_contiguous(input, value, _builder) # ----------------------- @@ -690,24 +667,14 @@ def maximum(x, y): @triton.jit +@_add_math_1arg_docstr("sigmoid") def sigmoid(x): - """ - Computes the element-wise sigmoid of :code:`x`. - - :param x: the input block - :type x: Block - """ return 1 / (1 + triton.language.exp(-x)) @triton.jit +@_add_math_1arg_docstr("softmax") def softmax(x): - """ - Computes the element-wise softmax of :code:`x`. - - :param x: the input block - :type x: Block - """ z = x - triton.language.max(x, 0) num = triton.language.exp(z) den = triton.language.sum(num, 0) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 80207d8cf..1d9fea638 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -322,7 +322,7 @@ else: triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_vals=[ - 128 * i for i in range(1, 33) + 128 * i for i in range(2, 33) ], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg``