[DOCS] Updates and improvements (#87)
This commit is contained in:
committed by
Philippe Tillet
parent
39f4730305
commit
29e33e50b7
35
docs/conf.py
35
docs/conf.py
@@ -23,17 +23,34 @@
|
|||||||
|
|
||||||
# -- General configuration ------------------------------------------------
|
# -- General configuration ------------------------------------------------
|
||||||
|
|
||||||
# If your documentation needs a minimal Sphinx version, state it here.
|
|
||||||
#
|
|
||||||
# needs_sphinx = '1.0'
|
|
||||||
|
|
||||||
# Add any Sphinx extension module names here, as strings. They can be
|
def setup(app):
|
||||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
"""Customize function args retrieving to get args under decorator."""
|
||||||
# ones.
|
import sphinx
|
||||||
extensions = []
|
import triton
|
||||||
|
|
||||||
|
def forward_jit_fn(func):
|
||||||
|
old = func
|
||||||
|
|
||||||
|
def wrapped(obj, **kwargs):
|
||||||
|
if isinstance(obj, triton.code_gen.JITFunction):
|
||||||
|
obj = obj.fn
|
||||||
|
return old(obj)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
old_documenter = sphinx.ext.autosummary.get_documenter
|
||||||
|
|
||||||
|
def documenter(app, obj, parent):
|
||||||
|
if isinstance(obj, triton.code_gen.JITFunction):
|
||||||
|
obj = obj.fn
|
||||||
|
return old_documenter(app, obj, parent)
|
||||||
|
|
||||||
|
sphinx.ext.autosummary.get_documenter = documenter
|
||||||
|
sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all)
|
||||||
|
sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature)
|
||||||
|
sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description)
|
||||||
|
|
||||||
# Math Jax
|
|
||||||
extensions += ['sphinx.ext.mathjax']
|
|
||||||
|
|
||||||
# Auto Doc
|
# Auto Doc
|
||||||
import sys
|
import sys
|
||||||
|
@@ -53,7 +53,7 @@ class CMakeBuild(build_ext):
|
|||||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||||
# create build directories
|
# create build directories
|
||||||
build_suffix = 'debug' if self.debug else 'release'
|
build_suffix = 'debug' if self.debug else 'release'
|
||||||
llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}")
|
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
|
||||||
if not os.path.exists(self.build_temp):
|
if not os.path.exists(self.build_temp):
|
||||||
os.makedirs(self.build_temp)
|
os.makedirs(self.build_temp)
|
||||||
if not os.path.exists(llvm_build_dir):
|
if not os.path.exists(llvm_build_dir):
|
||||||
|
@@ -5,6 +5,7 @@ import torch
|
|||||||
from .code_gen import jit, autotune, heuristics, Config, Autotuner
|
from .code_gen import jit, autotune, heuristics, Config, Autotuner
|
||||||
from .core import *
|
from .core import *
|
||||||
|
|
||||||
|
from . import code_gen
|
||||||
from . import testing
|
from . import testing
|
||||||
from . import ops
|
from . import ops
|
||||||
# version
|
# version
|
||||||
|
@@ -573,12 +573,14 @@ class Autotuner:
|
|||||||
|
|
||||||
class JITFunction:
|
class JITFunction:
|
||||||
def __init__(self, fn):
|
def __init__(self, fn):
|
||||||
|
self.fn = fn
|
||||||
self.module = fn.__module__
|
self.module = fn.__module__
|
||||||
self.arg_names = inspect.getfullargspec(fn).args
|
self.arg_names = inspect.getfullargspec(fn).args
|
||||||
self.cache = dict()
|
self.cache = dict()
|
||||||
self.kernel_decorators = []
|
self.kernel_decorators = []
|
||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
self.kernel = None
|
self.kernel = None
|
||||||
|
self.__doc__ = fn.__doc__
|
||||||
|
|
||||||
# we do not parse in the constructor because
|
# we do not parse in the constructor because
|
||||||
# the user might want to monkey-patch self.src dynamically.
|
# the user might want to monkey-patch self.src dynamically.
|
||||||
|
@@ -56,8 +56,8 @@ def builtin(fn):
|
|||||||
|
|
||||||
if wrapper.__doc__:
|
if wrapper.__doc__:
|
||||||
wrapper.__doc__ += """\
|
wrapper.__doc__ += """\
|
||||||
:param builder: IR builder to generate code into, optional from within @triton.jit functions
|
:param builder: IR builder to generate code into
|
||||||
:type builder: triton.ir.builder
|
:type builder: triton.ir.builder, optional from within JIT'ed functions
|
||||||
"""
|
"""
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -236,8 +236,7 @@ class block:
|
|||||||
@builtin
|
@builtin
|
||||||
def program_id(axis, builder=None):
|
def program_id(axis, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the id of the current program instance along the given `axis`.
|
Returns the id of the current program instance along the given :code:`axis`.
|
||||||
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
|
|
||||||
|
|
||||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||||
:type axis: int
|
:type axis: int
|
||||||
@@ -248,7 +247,7 @@ def program_id(axis, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def num_programs(axis, builder=None):
|
def num_programs(axis, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the number of program instances launched along the given `axis`.
|
Returns the number of program instances launched along the given :code:`axis`.
|
||||||
|
|
||||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||||
:type axis: int
|
:type axis: int
|
||||||
@@ -264,11 +263,11 @@ def num_programs(axis, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def arange(start, end, builder=None):
|
def arange(start, end, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns contiguous values within the open interval [start, end).
|
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
|
||||||
|
|
||||||
:param start: Start of the interval.
|
:param start: Start of the interval. Must be a power of two.
|
||||||
:type start: int
|
:type start: int
|
||||||
:param stop: End of the interval.
|
:param stop: End of the interval. Must be a power of two >= start.
|
||||||
:type stop: int
|
:type stop: int
|
||||||
"""
|
"""
|
||||||
return frontend.arange(start, end, builder)
|
return frontend.arange(start, end, builder)
|
||||||
@@ -277,12 +276,12 @@ def arange(start, end, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def zeros(shape, dtype, builder=None):
|
def zeros(shape, dtype, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns a block filled with the scalar value 0 and the given shape.
|
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
||||||
|
|
||||||
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
||||||
:type shape: tuple of ints
|
:type shape: tuple of ints
|
||||||
:param dtype: Data-type of the new array, e.g., triton.float16
|
:param dtype: Data-type of the new array, e.g., :code:`triton.float16`
|
||||||
:type dtype: triton.ir.dtype
|
:type dtype: DType
|
||||||
"""
|
"""
|
||||||
return frontend.zeros(shape, dtype, builder)
|
return frontend.zeros(shape, dtype, builder)
|
||||||
|
|
||||||
@@ -295,12 +294,12 @@ def zeros(shape, dtype, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def broadcast(input, other, builder=None):
|
def broadcast(input, other, builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to broadcast two blocks to a common compatible shape.
|
Tries to broadcast the two given blocks to a common compatible shape.
|
||||||
|
|
||||||
:param input: The first input block.
|
:param input: The first input block.
|
||||||
:type input: triton.ir.value
|
:type input: Block
|
||||||
:param other: The second input block.
|
:param other: The second input block.
|
||||||
:type other: triton.ir.value
|
:type other: Block
|
||||||
"""
|
"""
|
||||||
return frontend.broadcast(input, other, builder)
|
return frontend.broadcast(input, other, builder)
|
||||||
|
|
||||||
@@ -308,12 +307,12 @@ def broadcast(input, other, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def broadcast_to(input, shape, builder=None):
|
def broadcast_to(input, shape, builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to broadcast a block to a new shape.
|
Tries to broadcast the given block to a new :code:`shape`.
|
||||||
|
|
||||||
:param input: The input block.
|
:param input: The input block.
|
||||||
:type input: triton.value
|
:type input: Block
|
||||||
:param shape: The new shape.
|
:param shape: The desired shape.
|
||||||
:type shape: tuple of int
|
:type shape: Tuple[int]
|
||||||
"""
|
"""
|
||||||
return frontend.broadcast_to(input, shape, builder)
|
return frontend.broadcast_to(input, shape, builder)
|
||||||
|
|
||||||
@@ -321,7 +320,13 @@ def broadcast_to(input, shape, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def reshape(input, shape, builder=None):
|
def reshape(input, shape, builder=None):
|
||||||
"""
|
"""
|
||||||
Reshapes a block to a new shape.
|
Tries to reshape the given block to a new shape.
|
||||||
|
|
||||||
|
:param input: The input block.
|
||||||
|
:type input:
|
||||||
|
:param shape: The desired shape.
|
||||||
|
:type shape: Tuple[int]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return frontend.reshape(input, shape, builder)
|
return frontend.reshape(input, shape, builder)
|
||||||
|
|
||||||
@@ -335,12 +340,13 @@ def reshape(input, shape, builder=None):
|
|||||||
def dot(input, other, builder=None):
|
def dot(input, other, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the matrix product of two blocks.
|
Returns the matrix product of two blocks.
|
||||||
|
|
||||||
The two blocks must be two dimensionals and have compatible inner dimensions.
|
The two blocks must be two dimensionals and have compatible inner dimensions.
|
||||||
|
|
||||||
:param input: The first block to be multiplied.
|
:param input: The first block to be multiplied.
|
||||||
:type input: 2D block of scalar-type in {`float16`, `float32`}
|
:type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
||||||
:param other: The second block to be multiplied.
|
:param other: The second block to be multiplied.
|
||||||
:type other: 2D block of scalar-type in {`float16`, `float32`}
|
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
||||||
"""
|
"""
|
||||||
return frontend.dot(input, other, builder)
|
return frontend.dot(input, other, builder)
|
||||||
|
|
||||||
@@ -353,14 +359,18 @@ def dot(input, other, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def load(pointer, mask=None, other=None, builder=None):
|
def load(pointer, mask=None, other=None, builder=None):
|
||||||
"""
|
"""
|
||||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
|
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||||
|
|
||||||
:param pointer: Pointer to the data to be loaded.
|
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
|
||||||
:type pointer: Block of triton.pointer
|
|
||||||
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
|
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
|
||||||
:type mask: Block of triton.bool, optional
|
|
||||||
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
|
:param pointer: Pointers to the data to be loaded.
|
||||||
:type other: Block of triton.value, optional
|
:type pointer: Block of dtype=triton.PointerDType
|
||||||
|
:param mask: if mask[idx] is false, do not load the data at address :code:`pointer[idx]`.
|
||||||
|
:type mask: Block of triton.int1, optional
|
||||||
|
:param other: if mask[idx] is false, return other[idx]
|
||||||
|
:type other: Block, optional
|
||||||
"""
|
"""
|
||||||
return frontend.load(pointer, mask, other, builder)
|
return frontend.load(pointer, mask, other, builder)
|
||||||
|
|
||||||
@@ -368,26 +378,47 @@ def load(pointer, mask=None, other=None, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def store(pointer, value, mask=None, builder=None):
|
def store(pointer, value, mask=None, builder=None):
|
||||||
"""
|
"""
|
||||||
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
|
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||||
|
|
||||||
:param pointer: The memory locations where the elements of `value` are stored.
|
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
|
||||||
:type pointer: Block of triton.pointer
|
|
||||||
|
:param pointer: The memory locations where the elements of :code:`value` are stored.
|
||||||
|
:type pointer: Block of dtype=triton.PointerDType
|
||||||
:param value: The block of elements to be stored.
|
:param value: The block of elements to be stored.
|
||||||
:type value: Block of triton.value
|
:type value: Block
|
||||||
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
|
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
|
||||||
:type mask: Block of triton.bool, optional
|
:type mask: Block of triton.int1, optional
|
||||||
"""
|
"""
|
||||||
return frontend.store(pointer, value, mask, builder)
|
return frontend.store(pointer, value, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_cas(ptr, cmp, val, builder=None):
|
def atomic_cas(pointer, cmp, val, builder=None):
|
||||||
return frontend.atomic_cas(ptr, cmp, val, builder)
|
"""
|
||||||
|
Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`.
|
||||||
|
|
||||||
|
:param pointer: The memory locations to compare-and-swap.
|
||||||
|
:type pointer: Block of dtype=triton.PointerDType
|
||||||
|
:param cmp: The values expected to be found in the atomic object
|
||||||
|
:type cmp: Block of dtype=`pointer.dtype.element_ty`
|
||||||
|
:param val: The values to copy in case the expected value matches the contained value.
|
||||||
|
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||||
|
"""
|
||||||
|
|
||||||
|
return frontend.atomic_cas(pointer, cmp, val, builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_xchg(ptr, val, builder=None):
|
def atomic_xchg(pointer, val, builder=None):
|
||||||
return frontend.atomic_xchg(ptr, val, builder)
|
"""
|
||||||
|
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values.
|
||||||
|
|
||||||
|
:param pointer: The memory locations which contain the old values
|
||||||
|
:type pointer: Block of dtype=triton.PointerDType
|
||||||
|
:param val: The new values to store
|
||||||
|
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||||
|
"""
|
||||||
|
return frontend.atomic_xchg(pointer, val, builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -398,11 +429,14 @@ def atomic_xchg(ptr, val, builder=None):
|
|||||||
@builtin
|
@builtin
|
||||||
def where(condition, x, y, builder=None):
|
def where(condition, x, y, builder=None):
|
||||||
"""
|
"""
|
||||||
Returns a block of elements from either `x` or `y`, depending on `condition`.
|
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
|
||||||
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
|
|
||||||
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
|
Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
|
||||||
The shape of `x` and `y` are both broadcast to the shape of `condition`.
|
|
||||||
`x` and `y` must have the data type.
|
If you want to avoid unintented memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
|
||||||
|
|
||||||
|
The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
|
||||||
|
:code:`x` and :code:`y` must have the data type.
|
||||||
|
|
||||||
:param condition: When True (nonzero), yield x, otherwise yield y.
|
:param condition: When True (nonzero), yield x, otherwise yield y.
|
||||||
:type condition: Block of triton.bool
|
:type condition: Block of triton.bool
|
||||||
@@ -419,11 +453,25 @@ def where(condition, x, y, builder=None):
|
|||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def exp(x, builder=None):
|
def exp(x, builder=None):
|
||||||
|
"""
|
||||||
|
Computes the element-wise exponential of :code:`x`
|
||||||
|
|
||||||
|
:param x: the input values
|
||||||
|
:type x: Block
|
||||||
|
"""
|
||||||
|
|
||||||
return frontend.exp(x, builder)
|
return frontend.exp(x, builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def log(x, builder=None):
|
def log(x, builder=None):
|
||||||
|
"""
|
||||||
|
Computes the element-wise natural logarithm of :code:`x`
|
||||||
|
|
||||||
|
:param x: the input values
|
||||||
|
:type x: Block
|
||||||
|
"""
|
||||||
|
|
||||||
return frontend.log(x, builder)
|
return frontend.log(x, builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -434,16 +482,35 @@ def log(x, builder=None):
|
|||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def max(input, axis, builder=None):
|
def max(input, axis, builder=None):
|
||||||
|
"""
|
||||||
|
Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`
|
||||||
|
|
||||||
|
:param input: the input values
|
||||||
|
:param axis: the dimension along which the reduction should be done
|
||||||
|
"""
|
||||||
return frontend.max(input, axis, builder)
|
return frontend.max(input, axis, builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def min(input, axis, builder=None):
|
def min(input, axis, builder=None):
|
||||||
|
"""
|
||||||
|
Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis`
|
||||||
|
|
||||||
|
:param input: the input values
|
||||||
|
:param axis: the dimension along which the reduction should be done
|
||||||
|
"""
|
||||||
return frontend.min(input, axis, builder)
|
return frontend.min(input, axis, builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def sum(input, axis, builder=None):
|
def sum(input, axis, builder=None):
|
||||||
|
"""
|
||||||
|
Returns the sum of all elements in the :code:`input` block along the provided :code:`axis`
|
||||||
|
|
||||||
|
:param input: the input values
|
||||||
|
:param axis: the dimension along which the reduction should be done
|
||||||
|
"""
|
||||||
|
|
||||||
return frontend.sum(input, axis, builder)
|
return frontend.sum(input, axis, builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -458,8 +525,11 @@ def debug_barrier(builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def multiple_of(x, value, builder=None):
|
def multiple_of(input, value, builder=None):
|
||||||
return frontend.multiple_of(x, value, builder)
|
"""
|
||||||
|
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
|
||||||
|
"""
|
||||||
|
return frontend.multiple_of(input, value, builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -469,31 +539,65 @@ def multiple_of(x, value, builder=None):
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def minimum(x, y):
|
def minimum(x, y):
|
||||||
|
"""
|
||||||
|
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
||||||
|
|
||||||
|
:param input: the first input block
|
||||||
|
:type input: Block
|
||||||
|
:param other: the second input block
|
||||||
|
:type other: Block
|
||||||
|
"""
|
||||||
return triton.where(x < y, x, y)
|
return triton.where(x < y, x, y)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def maximum(x, y):
|
def maximum(x, y):
|
||||||
|
"""
|
||||||
|
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
||||||
|
|
||||||
|
:param input: the first input block
|
||||||
|
:type input: Block
|
||||||
|
:param other: the second input block
|
||||||
|
:type other: Block
|
||||||
|
"""
|
||||||
return triton.where(x > y, x, y)
|
return triton.where(x > y, x, y)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def sigmoid(x):
|
def sigmoid(x):
|
||||||
|
"""
|
||||||
|
Computes the element-wise sigmoid of :code:`x`.
|
||||||
|
|
||||||
|
:param x: the input block
|
||||||
|
:type x: Block
|
||||||
|
"""
|
||||||
return 1 / (1 + np.exp(-x))
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def ravel(x):
|
|
||||||
return triton.reshape(x, [x.type.numel])
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def softmax(x):
|
def softmax(x):
|
||||||
|
"""
|
||||||
|
Computes the element-wise softmax of :code:`x`.
|
||||||
|
|
||||||
|
:param x: the input block
|
||||||
|
:type x: Block
|
||||||
|
"""
|
||||||
z = x - triton.max(x, 0)
|
z = x - triton.max(x, 0)
|
||||||
num = triton.exp(z)
|
num = triton.exp(z)
|
||||||
den = triton.sum(num, 0)
|
den = triton.sum(num, 0)
|
||||||
return num / den
|
return num / den
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def ravel(x):
|
||||||
|
"""
|
||||||
|
Returns a contiguous flattened view of :code:`x`
|
||||||
|
|
||||||
|
:param x: the input block
|
||||||
|
:type x: Block
|
||||||
|
"""
|
||||||
|
return triton.reshape(x, [x.type.numel])
|
||||||
|
|
||||||
|
|
||||||
def cdiv(x, y):
|
def cdiv(x, y):
|
||||||
return (x + y - 1) // y
|
return (x + y - 1) // y
|
||||||
|
@@ -1,5 +1,3 @@
|
|||||||
import torch
|
|
||||||
import triton
|
|
||||||
"""
|
"""
|
||||||
Vector Addition
|
Vector Addition
|
||||||
=================
|
=================
|
||||||
@@ -14,6 +12,9 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
|||||||
# Compute Kernel
|
# Compute Kernel
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _add(
|
def _add(
|
||||||
|
@@ -100,7 +100,7 @@ def softmax(x):
|
|||||||
# Allocate output
|
# Allocate output
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||||
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
|
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@@ -56,7 +56,7 @@ You will specifically learn about:
|
|||||||
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
# :force:
|
#
|
||||||
# pid_m = triton.program_id(0)
|
# pid_m = triton.program_id(0)
|
||||||
# pid_n = triton.program_id(1)
|
# pid_n = triton.program_id(1)
|
||||||
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||||
@@ -85,8 +85,8 @@ You will specifically learn about:
|
|||||||
# .. code-block:: Python
|
# .. code-block:: Python
|
||||||
#
|
#
|
||||||
# pid = triton.program_id(0);
|
# pid = triton.program_id(0);
|
||||||
# grid_m = (M + BLOCK_M - 1) / BLOCK_M;
|
# grid_m = (M + BLOCK_M - 1) // BLOCK_M;
|
||||||
# grid_n = (N + BLOCK_N - 1) / BLOCK_N;
|
# grid_n = (N + BLOCK_N - 1) // BLOCK_N;
|
||||||
# pid_m = pid / grid_n;
|
# pid_m = pid / grid_n;
|
||||||
# pid_n = pid % grid_n;
|
# pid_n = pid % grid_n;
|
||||||
#
|
#
|
||||||
@@ -95,15 +95,15 @@ You will specifically learn about:
|
|||||||
# One possible solution is to launch blocks in an order that promotes data reuse.
|
# One possible solution is to launch blocks in an order that promotes data reuse.
|
||||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||||
#
|
#
|
||||||
# .. code-block:: C
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
# pid = triton.program_id(0);
|
# pid = triton.program_id(0);
|
||||||
# width = GROUP_M * grid_n;
|
# width = GROUP_M * grid_n;
|
||||||
# group_id = pid / width;
|
# group_id = pid // width;
|
||||||
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
||||||
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||||
# pid_m = group_id * GROUP_M + (pid % group_size);
|
# pid_m = group_id * GROUP_M + (pid % group_size);
|
||||||
# pid_n = (pid % width) / (group_size);
|
# pid_n = (pid % width) // (group_size);
|
||||||
#
|
#
|
||||||
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||||
#
|
#
|
||||||
@@ -237,7 +237,7 @@ print(triton.testing.allclose(c_0, c_1))
|
|||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||||
x_vals=[8192], # different possible values for `x_name`
|
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||||
y_vals=['cublas', 'triton'], # possible keys for `y_name`
|
y_vals=['cublas', 'triton'], # possible keys for `y_name`
|
||||||
y_lines=["cuBLAS", "Triton"], # label name for the lines
|
y_lines=["cuBLAS", "Triton"], # label name for the lines
|
||||||
|
Reference in New Issue
Block a user