[DOCS] Updates and improvements (#87)

This commit is contained in:
Philippe Tillet
2021-04-22 10:27:02 -04:00
committed by Philippe Tillet
parent 39f4730305
commit 29e33e50b7
8 changed files with 195 additions and 70 deletions

View File

@@ -23,17 +23,34 @@
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be def setup(app):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom """Customize function args retrieving to get args under decorator."""
# ones. import sphinx
extensions = [] import triton
def forward_jit_fn(func):
old = func
def wrapped(obj, **kwargs):
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old(obj)
return wrapped
old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent):
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old_documenter(app, obj, parent)
sphinx.ext.autosummary.get_documenter = documenter
sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all)
sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature)
sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description)
# Math Jax
extensions += ['sphinx.ext.mathjax']
# Auto Doc # Auto Doc
import sys import sys

View File

@@ -53,7 +53,7 @@ class CMakeBuild(build_ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories # create build directories
build_suffix = 'debug' if self.debug else 'release' build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}") llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp) os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir): if not os.path.exists(llvm_build_dir):

View File

@@ -5,6 +5,7 @@ import torch
from .code_gen import jit, autotune, heuristics, Config, Autotuner from .code_gen import jit, autotune, heuristics, Config, Autotuner
from .core import * from .core import *
from . import code_gen
from . import testing from . import testing
from . import ops from . import ops
# version # version

View File

@@ -573,12 +573,14 @@ class Autotuner:
class JITFunction: class JITFunction:
def __init__(self, fn): def __init__(self, fn):
self.fn = fn
self.module = fn.__module__ self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args self.arg_names = inspect.getfullargspec(fn).args
self.cache = dict() self.cache = dict()
self.kernel_decorators = [] self.kernel_decorators = []
self.src = textwrap.dedent(inspect.getsource(fn)) self.src = textwrap.dedent(inspect.getsource(fn))
self.kernel = None self.kernel = None
self.__doc__ = fn.__doc__
# we do not parse in the constructor because # we do not parse in the constructor because
# the user might want to monkey-patch self.src dynamically. # the user might want to monkey-patch self.src dynamically.

View File

@@ -56,8 +56,8 @@ def builtin(fn):
if wrapper.__doc__: if wrapper.__doc__:
wrapper.__doc__ += """\ wrapper.__doc__ += """\
:param builder: IR builder to generate code into, optional from within @triton.jit functions :param builder: IR builder to generate code into
:type builder: triton.ir.builder :type builder: triton.ir.builder, optional from within JIT'ed functions
""" """
return wrapper return wrapper
@@ -236,8 +236,7 @@ class block:
@builtin @builtin
def program_id(axis, builder=None): def program_id(axis, builder=None):
""" """
Returns the id of the current program instance along the given `axis`. Returns the id of the current program instance along the given :code:`axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int :type axis: int
@@ -248,7 +247,7 @@ def program_id(axis, builder=None):
@builtin @builtin
def num_programs(axis, builder=None): def num_programs(axis, builder=None):
""" """
Returns the number of program instances launched along the given `axis`. Returns the number of program instances launched along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int :type axis: int
@@ -264,11 +263,11 @@ def num_programs(axis, builder=None):
@builtin @builtin
def arange(start, end, builder=None): def arange(start, end, builder=None):
""" """
Returns contiguous values within the open interval [start, end). Returns contiguous values within the open interval [:code:`start`, :code:`end`).
:param start: Start of the interval. :param start: Start of the interval. Must be a power of two.
:type start: int :type start: int
:param stop: End of the interval. :param stop: End of the interval. Must be a power of two >= start.
:type stop: int :type stop: int
""" """
return frontend.arange(start, end, builder) return frontend.arange(start, end, builder)
@@ -277,12 +276,12 @@ def arange(start, end, builder=None):
@builtin @builtin
def zeros(shape, dtype, builder=None): def zeros(shape, dtype, builder=None):
""" """
Returns a block filled with the scalar value 0 and the given shape. Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
:param shape: Shape of the new array, e.g., (8, 16) or (8, ) :param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints :type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16 :param dtype: Data-type of the new array, e.g., :code:`triton.float16`
:type dtype: triton.ir.dtype :type dtype: DType
""" """
return frontend.zeros(shape, dtype, builder) return frontend.zeros(shape, dtype, builder)
@@ -295,12 +294,12 @@ def zeros(shape, dtype, builder=None):
@builtin @builtin
def broadcast(input, other, builder=None): def broadcast(input, other, builder=None):
""" """
Tries to broadcast two blocks to a common compatible shape. Tries to broadcast the two given blocks to a common compatible shape.
:param input: The first input block. :param input: The first input block.
:type input: triton.ir.value :type input: Block
:param other: The second input block. :param other: The second input block.
:type other: triton.ir.value :type other: Block
""" """
return frontend.broadcast(input, other, builder) return frontend.broadcast(input, other, builder)
@@ -308,12 +307,12 @@ def broadcast(input, other, builder=None):
@builtin @builtin
def broadcast_to(input, shape, builder=None): def broadcast_to(input, shape, builder=None):
""" """
Tries to broadcast a block to a new shape. Tries to broadcast the given block to a new :code:`shape`.
:param input: The input block. :param input: The input block.
:type input: triton.value :type input: Block
:param shape: The new shape. :param shape: The desired shape.
:type shape: tuple of int :type shape: Tuple[int]
""" """
return frontend.broadcast_to(input, shape, builder) return frontend.broadcast_to(input, shape, builder)
@@ -321,7 +320,13 @@ def broadcast_to(input, shape, builder=None):
@builtin @builtin
def reshape(input, shape, builder=None): def reshape(input, shape, builder=None):
""" """
Reshapes a block to a new shape. Tries to reshape the given block to a new shape.
:param input: The input block.
:type input:
:param shape: The desired shape.
:type shape: Tuple[int]
""" """
return frontend.reshape(input, shape, builder) return frontend.reshape(input, shape, builder)
@@ -335,12 +340,13 @@ def reshape(input, shape, builder=None):
def dot(input, other, builder=None): def dot(input, other, builder=None):
""" """
Returns the matrix product of two blocks. Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions. The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied. :param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`} :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
:param other: The second block to be multiplied. :param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`} :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
""" """
return frontend.dot(input, other, builder) return frontend.dot(input, other, builder)
@@ -353,14 +359,18 @@ def dot(input, other, builder=None):
@builtin @builtin
def load(pointer, mask=None, other=None, builder=None): def load(pointer, mask=None, other=None, builder=None):
""" """
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`. Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
:param pointer: Pointer to the data to be loaded. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`. :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]` :param pointer: Pointers to the data to be loaded.
:type other: Block of triton.value, optional :type pointer: Block of dtype=triton.PointerDType
:param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
:type mask: Block of triton.int1, optional
:param other: if mask[idx] is false, return other[idx]
:type other: Block, optional
""" """
return frontend.load(pointer, mask, other, builder) return frontend.load(pointer, mask, other, builder)
@@ -368,26 +378,47 @@ def load(pointer, mask=None, other=None, builder=None):
@builtin @builtin
def store(pointer, value, mask=None, builder=None): def store(pointer, value, mask=None, builder=None):
""" """
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`. Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
:param pointer: The memory locations where the elements of `value` are stored. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
:type pointer: Block of triton.pointer
:param pointer: The memory locations where the elements of :code:`value` are stored.
:type pointer: Block of dtype=triton.PointerDType
:param value: The block of elements to be stored. :param value: The block of elements to be stored.
:type value: Block of triton.value :type value: Block
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`. :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
:type mask: Block of triton.bool, optional :type mask: Block of triton.int1, optional
""" """
return frontend.store(pointer, value, mask, builder) return frontend.store(pointer, value, mask, builder)
@builtin @builtin
def atomic_cas(ptr, cmp, val, builder=None): def atomic_cas(pointer, cmp, val, builder=None):
return frontend.atomic_cas(ptr, cmp, val, builder) """
Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
"""
return frontend.atomic_cas(pointer, cmp, val, builder)
@builtin @builtin
def atomic_xchg(ptr, val, builder=None): def atomic_xchg(pointer, val, builder=None):
return frontend.atomic_xchg(ptr, val, builder) """
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values.
:param pointer: The memory locations which contain the old values
:type pointer: Block of dtype=triton.PointerDType
:param val: The new values to store
:type val: Block of dtype=`pointer.dtype.element_ty`
"""
return frontend.atomic_xchg(pointer, val, builder)
# ----------------------- # -----------------------
@@ -398,11 +429,14 @@ def atomic_xchg(ptr, val, builder=None):
@builtin @builtin
def where(condition, x, y, builder=None): def where(condition, x, y, builder=None):
""" """
Returns a block of elements from either `x` or `y`, depending on `condition`. Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
The shape of `x` and `y` are both broadcast to the shape of `condition`.
`x` and `y` must have the data type. If you want to avoid unintented memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
:code:`x` and :code:`y` must have the data type.
:param condition: When True (nonzero), yield x, otherwise yield y. :param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool :type condition: Block of triton.bool
@@ -419,11 +453,25 @@ def where(condition, x, y, builder=None):
@builtin @builtin
def exp(x, builder=None): def exp(x, builder=None):
"""
Computes the element-wise exponential of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.exp(x, builder) return frontend.exp(x, builder)
@builtin @builtin
def log(x, builder=None): def log(x, builder=None):
"""
Computes the element-wise natural logarithm of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.log(x, builder) return frontend.log(x, builder)
@@ -434,16 +482,35 @@ def log(x, builder=None):
@builtin @builtin
def max(input, axis, builder=None): def max(input, axis, builder=None):
"""
Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.max(input, axis, builder) return frontend.max(input, axis, builder)
@builtin @builtin
def min(input, axis, builder=None): def min(input, axis, builder=None):
"""
Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.min(input, axis, builder) return frontend.min(input, axis, builder)
@builtin @builtin
def sum(input, axis, builder=None): def sum(input, axis, builder=None):
"""
Returns the sum of all elements in the :code:`input` block along the provided :code:`axis`
:param input: the input values
:param axis: the dimension along which the reduction should be done
"""
return frontend.sum(input, axis, builder) return frontend.sum(input, axis, builder)
@@ -458,8 +525,11 @@ def debug_barrier(builder=None):
@builtin @builtin
def multiple_of(x, value, builder=None): def multiple_of(input, value, builder=None):
return frontend.multiple_of(x, value, builder) """
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
"""
return frontend.multiple_of(input, value, builder)
# ----------------------- # -----------------------
@@ -469,31 +539,65 @@ def multiple_of(x, value, builder=None):
@triton.jit @triton.jit
def minimum(x, y): def minimum(x, y):
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.
:param input: the first input block
:type input: Block
:param other: the second input block
:type other: Block
"""
return triton.where(x < y, x, y) return triton.where(x < y, x, y)
@triton.jit @triton.jit
def maximum(x, y): def maximum(x, y):
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.
:param input: the first input block
:type input: Block
:param other: the second input block
:type other: Block
"""
return triton.where(x > y, x, y) return triton.where(x > y, x, y)
@triton.jit @triton.jit
def sigmoid(x): def sigmoid(x):
"""
Computes the element-wise sigmoid of :code:`x`.
:param x: the input block
:type x: Block
"""
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
@triton.jit
def ravel(x):
return triton.reshape(x, [x.type.numel])
@triton.jit @triton.jit
def softmax(x): def softmax(x):
"""
Computes the element-wise softmax of :code:`x`.
:param x: the input block
:type x: Block
"""
z = x - triton.max(x, 0) z = x - triton.max(x, 0)
num = triton.exp(z) num = triton.exp(z)
den = triton.sum(num, 0) den = triton.sum(num, 0)
return num / den return num / den
@triton.jit
def ravel(x):
"""
Returns a contiguous flattened view of :code:`x`
:param x: the input block
:type x: Block
"""
return triton.reshape(x, [x.type.numel])
def cdiv(x, y): def cdiv(x, y):
return (x + y - 1) // y return (x + y - 1) // y

View File

@@ -1,5 +1,3 @@
import torch
import triton
""" """
Vector Addition Vector Addition
================= =================
@@ -14,6 +12,9 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel # Compute Kernel
# -------------------------- # --------------------------
import torch
import triton
@triton.jit @triton.jit
def _add( def _add(

View File

@@ -100,7 +100,7 @@ def softmax(x):
# Allocate output # Allocate output
y = torch.empty_like(x) y = torch.empty_like(x)
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK) _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
return y return y

View File

@@ -56,7 +56,7 @@ You will specifically learn about:
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as: # Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
# #
# .. code-block:: python # .. code-block:: python
# :force: #
# pid_m = triton.program_id(0) # pid_m = triton.program_id(0)
# pid_n = triton.program_id(1) # pid_n = triton.program_id(1)
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) # rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
@@ -85,8 +85,8 @@ You will specifically learn about:
# .. code-block:: Python # .. code-block:: Python
# #
# pid = triton.program_id(0); # pid = triton.program_id(0);
# grid_m = (M + BLOCK_M - 1) / BLOCK_M; # grid_m = (M + BLOCK_M - 1) // BLOCK_M;
# grid_n = (N + BLOCK_N - 1) / BLOCK_N; # grid_n = (N + BLOCK_N - 1) // BLOCK_N;
# pid_m = pid / grid_n; # pid_m = pid / grid_n;
# pid_n = pid % grid_n; # pid_n = pid % grid_n;
# #
@@ -95,15 +95,15 @@ You will specifically learn about:
# One possible solution is to launch blocks in an order that promotes data reuse. # One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column: # This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
# #
# .. code-block:: C # .. code-block:: python
# #
# pid = triton.program_id(0); # pid = triton.program_id(0);
# width = GROUP_M * grid_n; # width = GROUP_M * grid_n;
# group_id = pid / width; # group_id = pid // width;
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0 # # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M); # group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
# pid_m = group_id * GROUP_M + (pid % group_size); # pid_m = group_id * GROUP_M + (pid % group_size);
# pid_n = (pid % width) / (group_size); # pid_n = (pid % width) // (group_size);
# #
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). # In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
# #
@@ -237,7 +237,7 @@ print(triton.testing.allclose(c_0, c_1))
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[8192], # different possible values for `x_name` x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton'], # possible keys for `y_name` y_vals=['cublas', 'triton'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton"], # label name for the lines y_lines=["cuBLAS", "Triton"], # label name for the lines