From 29e33e50b76e6b04b644f5d7da1c613a30b12937 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 22 Apr 2021 10:27:02 -0400
Subject: [PATCH] [DOCS] Updates and improvements (#87)

---
 docs/conf.py                                 | 35 +++-
 python/setup.py                              |  2 +-
 python/triton/__init__.py                    |  1 +
 python/triton/code_gen.py                    |  2 +
 python/triton/core.py                        | 204 ++++++++++++++-----
 python/tutorials/01-vector-add.py            |  5 +-
 python/tutorials/02-fused-softmax.py         |  2 +-
 python/tutorials/03-matrix-multiplication.py | 14 +-
 8 files changed, 195 insertions(+), 70 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 5c343cec0..7a8d8c3f5 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -23,17 +23,34 @@

 # -- General configuration ------------------------------------------------

-# If your documentation needs a minimal Sphinx version, state it here.
-#
-# needs_sphinx = '1.0'
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = []
+def setup(app):
+    """Customize how Sphinx retrieves function signatures and docstrings,
+    so that it can see through the @triton.jit decorator."""
+    import sphinx
+    import triton
+
+    def forward_jit_fn(func):
+        old = func
+
+        def wrapped(obj, **kwargs):
+            if isinstance(obj, triton.code_gen.JITFunction):
+                obj = obj.fn
+            return old(obj, **kwargs)
+
+        return wrapped
+
+    old_documenter = sphinx.ext.autosummary.get_documenter
+
+    def documenter(app, obj, parent):
+        if isinstance(obj, triton.code_gen.JITFunction):
+            obj = obj.fn
+        return old_documenter(app, obj, parent)
+
+    sphinx.ext.autosummary.get_documenter = documenter
+    sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all)
+    sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature)
+    sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description)

-# Math Jax
-extensions += ['sphinx.ext.mathjax']

 # Auto Doc
 import sys
diff --git a/python/setup.py b/python/setup.py
index c2a35b32c..636054b0e 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -53,7 +53,7 @@ class CMakeBuild(build_ext):
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
-        llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}")
+        llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
         if not os.path.exists(llvm_build_dir):
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index ebea7548b..663a9c2df 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -5,6 +5,7 @@ import torch
 from .code_gen import jit, autotune, heuristics, Config, Autotuner
 from .core import *
+from . import code_gen
 from . import testing
 from . import ops
 # version
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index b1a2eff11..a4c42070f 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -573,12 +573,14 @@ class Autotuner:

 class JITFunction:

     def __init__(self, fn):
+        self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
         self.cache = dict()
         self.kernel_decorators = []
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.kernel = None
+        self.__doc__ = fn.__doc__
         # we do not parse in the constructor because
         # the user might want to monkey-patch self.src dynamically.
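For context: :code:`@triton.jit` replaces the decorated function with a :code:`triton.code_gen.JITFunction` object, which Sphinx cannot introspect directly. The :code:`conf.py` hook above unwraps such objects, and the new :code:`fn` and :code:`__doc__` attributes on :code:`JITFunction` are what make that possible. A minimal sketch of the resulting behavior (the :code:`add` kernel below is a hypothetical example, not part of this patch):

.. code-block:: python

    import inspect
    import triton

    @triton.jit
    def add(x_ptr, y_ptr, z_ptr):
        """Adds two blocks, element-wise."""
        ...  # body omitted; JITFunction keeps the raw source and parses it lazily

    # `add` is now a JITFunction wrapper, not a plain Python function ...
    print(type(add))                  # <class 'triton.code_gen.JITFunction'>
    # ... but the docstring is forwarded, and the wrapped function stays
    # reachable, which is what the Sphinx hooks above rely on:
    print(add.__doc__)                # Adds two blocks, element-wise.
    print(inspect.signature(add.fn))  # (x_ptr, y_ptr, z_ptr)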
diff --git a/python/triton/core.py b/python/triton/core.py
index 8325990c1..a84435df5 100644
--- a/python/triton/core.py
+++ b/python/triton/core.py
@@ -56,8 +56,8 @@ def builtin(fn):
     if wrapper.__doc__:
         wrapper.__doc__ += """\
-:param builder: IR builder to generate code into, optional from within @triton.jit functions
-    :type builder: triton.ir.builder
+:param builder: IR builder to generate code into
+    :type builder: triton.ir.builder, optional from within JIT'ed functions
 """
     return wrapper
@@ -236,8 +236,7 @@ class block:
 @builtin
 def program_id(axis, builder=None):
     """
-    Returns the id of the current program instance along the given `axis`.
-    Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
+    Returns the id of the current program instance along the given :code:`axis`.

     :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
     :type axis: int
@@ -248,7 +247,7 @@
 @builtin
 def num_programs(axis, builder=None):
     """
-    Returns the number of program instances launched along the given `axis`.
+    Returns the number of program instances launched along the given :code:`axis`.

     :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
     :type axis: int
@@ -264,11 +263,11 @@
 @builtin
 def arange(start, end, builder=None):
     """
-    Returns contiguous values within the open interval [start, end).
+    Returns contiguous values within the half-open interval [:code:`start`, :code:`end`).

-    :param start: Start of the interval.
+    :param start: Start of the interval. Must be a power of two.
     :type start: int
-    :param stop: End of the interval.
+    :param stop: End of the interval. Must be a power of two greater than or equal to :code:`start`.
     :type stop: int
     """
     return frontend.arange(start, end, builder)
@@ -277,12 +276,12 @@
 @builtin
 def zeros(shape, dtype, builder=None):
     """
-    Returns a block filled with the scalar value 0 and the given shape.
+    Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

     :param shape: Shape of the new array, e.g., (8, 16) or (8, )
     :type shape: tuple of ints
-    :param dtype: Data-type of the new array, e.g., triton.float16
-    :type dtype: triton.ir.dtype
+    :param dtype: Data-type of the new array, e.g., :code:`triton.float16`
+    :type dtype: DType
     """
     return frontend.zeros(shape, dtype, builder)
@@ -295,12 +294,12 @@
 @builtin
 def broadcast(input, other, builder=None):
     """
-    Tries to broadcast two blocks to a common compatible shape.
+    Tries to broadcast the two given blocks to a common compatible shape.

     :param input: The first input block.
-    :type input: triton.ir.value
+    :type input: Block
     :param other: The second input block.
-    :type other: triton.ir.value
+    :type other: Block
     """
     return frontend.broadcast(input, other, builder)
@@ -308,12 +307,12 @@
 @builtin
 def broadcast_to(input, shape, builder=None):
     """
-    Tries to broadcast a block to a new shape.
+    Tries to broadcast the given block to a new :code:`shape`.

     :param input: The input block.
-    :type input: triton.value
-    :param shape: The new shape.
-    :type shape: tuple of int
+    :type input: Block
+    :param shape: The desired shape.
+    :type shape: Tuple[int]
     """
     return frontend.broadcast_to(input, shape, builder)
@@ -321,7 +320,13 @@
 @builtin
 def reshape(input, shape, builder=None):
     """
-    Reshapes a block to a new shape.
+    Tries to reshape the given block to a new shape.
+
+    :param input: The input block.
+    :type input: Block
+    :param shape: The desired shape.
+    :type shape: Tuple[int]
+
     """
     return frontend.reshape(input, shape, builder)
@@ -335,12 +340,13 @@ def reshape(input, shape, builder=None):
 def dot(input, other, builder=None):
     """
     Returns the matrix product of two blocks.
+    The two blocks must be two-dimensional and have compatible inner dimensions.

     :param input: The first block to be multiplied.
-    :type input: 2D block of scalar-type in {`float16`, `float32`}
+    :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
     :param other: The second block to be multiplied.
-    :type other: 2D block of scalar-type in {`float16`, `float32`}
+    :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
     """
     return frontend.dot(input, other, builder)
@@ -353,14 +359,18 @@ def dot(input, other, builder=None):
 @builtin
 def load(pointer, mask=None, other=None, builder=None):
     """
-    Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
+    Returns a block of data whose values are loaded, element-wise, from memory at the locations defined by :code:`pointer`.

-    :param pointer: Pointer to the data to be loaded.
-    :type pointer: Block of triton.pointer
-    :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
-    :type mask: Block of triton.bool, optional
-    :param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
-    :type other: Block of triton.value, optional
+    :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
+
+    :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
+
+    :param pointer: Pointers to the data to be loaded.
+    :type pointer: Block of dtype=triton.PointerDType
+    :param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
+    :type mask: Block of triton.int1, optional
+    :param other: if mask[idx] is false, return other[idx] instead.
+    :type other: Block, optional
     """
     return frontend.load(pointer, mask, other, builder)
@@ -368,26 +378,47 @@
 @builtin
 def store(pointer, value, mask=None, builder=None):
     """
-    Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
+    Stores the :code:`value` block of elements, element-wise, at the memory locations specified by :code:`pointer`.

-    :param pointer: The memory locations where the elements of `value` are stored.
-    :type pointer: Block of triton.pointer
+    :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
+
+    :param pointer: The memory locations where the elements of :code:`value` are stored.
+    :type pointer: Block of dtype=triton.PointerDType
     :param value: The block of elements to be stored.
-    :type value: Block of triton.value
-    :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
-    :type mask: Block of triton.bool, optional
+    :type value: Block
+    :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
+    :type mask: Block of triton.int1, optional
     """
     return frontend.store(pointer, value, mask, builder)

 @builtin
-def atomic_cas(ptr, cmp, val, builder=None):
-    return frontend.atomic_cas(ptr, cmp, val, builder)
+def atomic_cas(pointer, cmp, val, builder=None):
+    """
+    Performs an atomic compare-and-swap at the memory locations specified by :code:`pointer`.
+
+    :param pointer: The memory locations to compare-and-swap.
+    :type pointer: Block of dtype=triton.PointerDType
+    :param cmp: The values expected to be found in the atomic object.
+    :type cmp: Block of dtype=`pointer.dtype.element_ty`
+    :param val: The values to copy in case the expected value matches the contained value.
+    :type val: Block of dtype=`pointer.dtype.element_ty`
+    """
+
+    return frontend.atomic_cas(pointer, cmp, val, builder)

 @builtin
-def atomic_xchg(ptr, val, builder=None):
-    return frontend.atomic_xchg(ptr, val, builder)
+def atomic_xchg(pointer, val, builder=None):
+    """
+    Swaps the *old* values stored at the memory locations specified by :code:`pointer` with the new values given by :code:`val`. Returns the old values.
+
+    :param pointer: The memory locations which contain the old values.
+    :type pointer: Block of dtype=triton.PointerDType
+    :param val: The new values to store.
+    :type val: Block of dtype=`pointer.dtype.element_ty`
+    """
+    return frontend.atomic_xchg(pointer, val, builder)

 # -----------------------
@@ -398,11 +429,14 @@
 @builtin
 def where(condition, x, y, builder=None):
     """
-    Returns a block of elements from either `x` or `y`, depending on `condition`.
-    Note that `x` and `y` are always evaluated regardless of the value of `condition`.
-    If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
-    The shape of `x` and `y` are both broadcast to the shape of `condition`.
-    `x` and `y` must have the data type.
+    Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
+
+    Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
+
+    If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
+
+    The shapes of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
+    :code:`x` and :code:`y` must have the same data type.

     :param condition: When True (nonzero), yield x, otherwise yield y.
     :type condition: Block of triton.bool
@@ -419,11 +453,25 @@
 @builtin
 def exp(x, builder=None):
+    """
+    Computes the element-wise exponential of :code:`x`.
+
+    :param x: the input values
+    :type x: Block
+    """
+
     return frontend.exp(x, builder)

 @builtin
 def log(x, builder=None):
+    """
+    Computes the element-wise natural logarithm of :code:`x`.
+
+    :param x: the input values
+    :type x: Block
+    """
+
     return frontend.log(x, builder)
@@ -434,16 +482,35 @@
 @builtin
 def max(input, axis, builder=None):
+    """
+    Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`.
+
+    :param input: the input values
+    :param axis: the dimension along which the reduction should be done
+    """
     return frontend.max(input, axis, builder)

 @builtin
 def min(input, axis, builder=None):
+    """
+    Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis`.
+
+    :param input: the input values
+    :param axis: the dimension along which the reduction should be done
+    """
     return frontend.min(input, axis, builder)

 @builtin
 def sum(input, axis, builder=None):
+    """
+    Returns the sum of all elements in the :code:`input` block along the provided :code:`axis`.
+
+    :param input: the input values
+    :param axis: the dimension along which the reduction should be done
+    """
+
     return frontend.sum(input, axis, builder)
@@ -458,8 +525,11 @@ def debug_barrier(builder=None):

 @builtin
-def multiple_of(x, value, builder=None):
-    return frontend.multiple_of(x, value, builder)
+def multiple_of(input, value, builder=None):
+    """
+    Lets the compiler know that the values in :code:`input` are all multiples of :code:`value`.
+    """
+    return frontend.multiple_of(input, value, builder)

 # -----------------------
@@ -469,31 +539,65 @@ def multiple_of(x, value, builder=None):

 @triton.jit
 def minimum(x, y):
+    """
+    Computes the element-wise minimum of :code:`x` and :code:`y`.
+
+    :param x: the first input block
+    :type x: Block
+    :param y: the second input block
+    :type y: Block
+    """
     return triton.where(x < y, x, y)

 @triton.jit
 def maximum(x, y):
+    """
+    Computes the element-wise maximum of :code:`x` and :code:`y`.
+
+    :param x: the first input block
+    :type x: Block
+    :param y: the second input block
+    :type y: Block
+    """
     return triton.where(x > y, x, y)

 @triton.jit
 def sigmoid(x):
+    """
+    Computes the element-wise sigmoid of :code:`x`.
+
+    :param x: the input block
+    :type x: Block
+    """
     return 1 / (1 + np.exp(-x))

-@triton.jit
-def ravel(x):
-    return triton.reshape(x, [x.type.numel])
-
-
 @triton.jit
 def softmax(x):
+    """
+    Computes the softmax of :code:`x` along its first axis.
+
+    :param x: the input block
+    :type x: Block
+    """
     z = x - triton.max(x, 0)
     num = triton.exp(z)
     den = triton.sum(num, 0)
     return num / den

+@triton.jit
+def ravel(x):
+    """
+    Returns a contiguous flattened view of :code:`x`.
+
+    :param x: the input block
+    :type x: Block
+    """
+    return triton.reshape(x, [x.type.numel])
+
+
 def cdiv(x, y):
     return (x + y - 1) // y
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index b1c82a34a..29797d579 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -1,5 +1,3 @@
-import torch
-import triton
 """
 Vector Addition
 =================
@@ -14,6 +12,9 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # Compute Kernel
 # --------------------------

+import torch
+import triton
+

 @triton.jit
 def _add(
diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index 3c5d674c2..f9b1b5103 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -100,7 +100,7 @@ def softmax(x):
     # Allocate output
     y = torch.empty_like(x)
     # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
-    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
+    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
     return y
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 8f65a91b4..0b7db9387 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -56,7 +56,7 @@ You will specifically learn about:
 # Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
 #
 # .. code-block:: python
-#    :force:
+#
 #    pid_m = triton.program_id(0)
 #    pid_n = triton.program_id(1)
 #    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
@@ -85,8 +85,8 @@ You will specifically learn about:
 # .. code-block:: Python
 #
 #    pid = triton.program_id(0);
-#    grid_m = (M + BLOCK_M - 1) / BLOCK_M;
-#    grid_n = (N + BLOCK_N - 1) / BLOCK_N;
+#    grid_m = (M + BLOCK_M - 1) // BLOCK_M;
+#    grid_n = (N + BLOCK_N - 1) // BLOCK_N;
 #    pid_m = pid / grid_n;
 #    pid_n = pid % grid_n;
 #
@@ -95,15 +95,15 @@
 # One possible solution is to launch blocks in an order that promotes data reuse.
 # This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
 #
-# .. code-block:: C
+# .. code-block:: python
 #
 #    pid = triton.program_id(0);
 #    width = GROUP_M * grid_n;
-#    group_id = pid / width;
+#    group_id = pid // width;
 #    # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
 #    group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
 #    pid_m = group_id * GROUP_M + (pid % group_size);
-#    pid_n = (pid % width) / (group_size);
+#    pid_n = (pid % width) // (group_size);
 #
 # In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
 #
@@ -237,7 +237,7 @@ print(triton.testing.allclose(c_0, c_1))
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
-        x_vals=[8192],  # different possible values for `x_name`
+        x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
         y_name='provider',  # argument name whose value corresponds to a different line in the plot
         y_vals=['cublas', 'triton'],  # possible keys for `y_name`
         y_lines=["cuBLAS", "Triton"],  # label name for the lines
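For reference, the new :code:`x_vals` expression sweeps square matrix sizes M = N = K from 512 to 8192 in steps of 256, instead of the single 8192 point used before. A quick sanity check of the generated values (a sketch, not part of the patch):

.. code-block:: python

    # Matrix sizes produced by the new `x_vals` expression: 512, 768, ..., 8192
    x_vals = [256 * i for i in range(2, 33)]
    assert x_vals[0] == 512
    assert x_vals[-1] == 8192
    assert len(x_vals) == 31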