[DOCS] Various improvements (#224)
- Added docstr for autotune, Config, heuristics - Added docstr for atomics - Hiding internal _builder argument used for built-in language primitives - Re-factor docstr to use common templates between similar functions.
This commit is contained in:
10
docs/conf.py
10
docs/conf.py
@@ -24,11 +24,20 @@
|
|||||||
# -- General configuration ------------------------------------------------
|
# -- General configuration ------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def process_sig(app, what, name, obj, options, signature, return_annotation):
|
||||||
|
if signature and '_builder' in signature:
|
||||||
|
signature = signature.split('_builder')[0] + ")"
|
||||||
|
return (signature, return_annotation)
|
||||||
|
|
||||||
def setup(app):
|
def setup(app):
|
||||||
"""Customize function args retrieving to get args under decorator."""
|
"""Customize function args retrieving to get args under decorator."""
|
||||||
import sphinx
|
import sphinx
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
|
app.connect("autodoc-process-signature", process_sig)
|
||||||
|
|
||||||
def forward_jit_fn(func):
|
def forward_jit_fn(func):
|
||||||
old = func
|
old = func
|
||||||
|
|
||||||
@@ -39,6 +48,7 @@ def setup(app):
|
|||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
old_documenter = sphinx.ext.autosummary.get_documenter
|
old_documenter = sphinx.ext.autosummary.get_documenter
|
||||||
|
|
||||||
def documenter(app, obj, parent):
|
def documenter(app, obj, parent):
|
||||||
|
@@ -98,6 +98,18 @@ Reduction Ops
|
|||||||
min
|
min
|
||||||
sum
|
sum
|
||||||
|
|
||||||
|
Atomic Ops
|
||||||
|
---------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
atomic_cas
|
||||||
|
atomic_add
|
||||||
|
atomic_max
|
||||||
|
atomic_min
|
||||||
|
|
||||||
|
|
||||||
Comparison ops
|
Comparison ops
|
||||||
---------------
|
---------------
|
||||||
|
@@ -8,3 +8,6 @@ triton
|
|||||||
:nosignatures:
|
:nosignatures:
|
||||||
|
|
||||||
jit
|
jit
|
||||||
|
autotune
|
||||||
|
heuristics
|
||||||
|
Config
|
@@ -193,11 +193,11 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
kws = dict()
|
kws = dict()
|
||||||
|
|
||||||
if self.is_triton_object(lhs):
|
if self.is_triton_object(lhs):
|
||||||
kws['builder'] = self.builder
|
kws['_builder'] = self.builder
|
||||||
ret = getattr(lhs, fn)(rhs, **kws)
|
ret = getattr(lhs, fn)(rhs, **kws)
|
||||||
if ret is NotImplemented:
|
if ret is NotImplemented:
|
||||||
if self.is_triton_object(rhs):
|
if self.is_triton_object(rhs):
|
||||||
kws['builder'] = self.builder
|
kws['_builder'] = self.builder
|
||||||
fn = fn[:2] + 'r' + fn[2:]
|
fn = fn[:2] + 'r' + fn[2:]
|
||||||
ret = getattr(rhs, fn)(lhs, **kws)
|
ret = getattr(rhs, fn)(lhs, **kws)
|
||||||
return ret
|
return ret
|
||||||
@@ -260,10 +260,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ast.IsNot: '__ne__',
|
ast.IsNot: '__ne__',
|
||||||
}[type(node.ops[0])]
|
}[type(node.ops[0])]
|
||||||
if self.is_triton_object(lhs):
|
if self.is_triton_object(lhs):
|
||||||
return getattr(lhs, fn)(rhs, builder=self.builder)
|
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||||
elif self.is_triton_object(rhs):
|
elif self.is_triton_object(rhs):
|
||||||
fn = fn[:2] + 'r' + fn[2:]
|
fn = fn[:2] + 'r' + fn[2:]
|
||||||
return getattr(rhs, fn)(lhs, builder=self.builder)
|
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||||
else:
|
else:
|
||||||
return getattr(lhs, fn)(rhs)
|
return getattr(lhs, fn)(rhs)
|
||||||
|
|
||||||
@@ -275,7 +275,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ast.Invert: '__invert__',
|
ast.Invert: '__invert__',
|
||||||
}[type(node.op)]
|
}[type(node.op)]
|
||||||
if self.is_triton_object(op):
|
if self.is_triton_object(op):
|
||||||
return getattr(op, fn)(builder=self.builder)
|
return getattr(op, fn)(_builder=self.builder)
|
||||||
return getattr(op, fn)()
|
return getattr(op, fn)()
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
@@ -308,7 +308,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
lhs = self.visit(node.value)
|
lhs = self.visit(node.value)
|
||||||
slices = self.visit(node.slice)
|
slices = self.visit(node.slice)
|
||||||
if self.is_triton_object(lhs):
|
if self.is_triton_object(lhs):
|
||||||
return lhs.__getitem__(slices, builder=self.builder)
|
return lhs.__getitem__(slices, _builder=self.builder)
|
||||||
return lhs[slices]
|
return lhs[slices]
|
||||||
|
|
||||||
def visit_ExtSlice(self, node):
|
def visit_ExtSlice(self, node):
|
||||||
@@ -331,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||||
self.visit(pos_cond_node),\
|
self.visit(pos_cond_node),\
|
||||||
self.visit(neg_cond_node),\
|
self.visit(neg_cond_node),\
|
||||||
builder=self.builder)
|
_builder=self.builder)
|
||||||
#cond_node = neg_cond_node
|
#cond_node = neg_cond_node
|
||||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||||
# code generation
|
# code generation
|
||||||
@@ -385,7 +385,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return fn(*args, generator=self, **kws)
|
return fn(*args, generator=self, **kws)
|
||||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||||
sys.modules[fn.__module__] is triton.language:
|
sys.modules[fn.__module__] is triton.language:
|
||||||
return fn(*args, builder=self.builder, **kws)
|
return fn(*args, _builder=self.builder, **kws)
|
||||||
return fn(*args, **kws)
|
return fn(*args, **kws)
|
||||||
|
|
||||||
def visit_Num(self, node):
|
def visit_Num(self, node):
|
||||||
@@ -714,6 +714,19 @@ class JITFunction:
|
|||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
"""
|
||||||
|
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||||
|
|
||||||
|
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||||
|
:type meta: dict[str, Any]
|
||||||
|
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||||
|
`num_warps=8`, then each kernel instance will be automatically parallelized to
|
||||||
|
cooperatively execute using `8 * 32 = 256` threads.
|
||||||
|
:type num_warps: int
|
||||||
|
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||||
|
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||||
|
:type num_stages: int
|
||||||
|
"""
|
||||||
def __init__(self, meta, num_warps=4, num_stages=2):
|
def __init__(self, meta, num_warps=4, num_stages=2):
|
||||||
self.meta = meta
|
self.meta = meta
|
||||||
self.num_warps = num_warps
|
self.num_warps = num_warps
|
||||||
@@ -721,6 +734,35 @@ class Config:
|
|||||||
|
|
||||||
|
|
||||||
def autotune(configs, key, reset_to_zero=None):
|
def autotune(configs, key, reset_to_zero=None):
|
||||||
|
"""
|
||||||
|
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||||
|
|
||||||
|
.. highlight:: python
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@triton.autotune(configs=[
|
||||||
|
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||||
|
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||||
|
],
|
||||||
|
key=['x_size'] # the two above configs will be evaluated anytime
|
||||||
|
# the value of x_size changes
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def kernel(x_ptr, x_size, **META):
|
||||||
|
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||||
|
|
||||||
|
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||||
|
This means that whatever value the kernel updates will be updated multiple times.
|
||||||
|
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||||
|
resets the value of the provided tensor to `zero` before running any configuration.
|
||||||
|
|
||||||
|
:param configs: a list of :code:`triton.Config` objects
|
||||||
|
:type configs: list[triton.Config]
|
||||||
|
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||||
|
:type key: list[str]
|
||||||
|
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||||
|
:type reset_to_zero: list[str]
|
||||||
|
"""
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
def wrapper(kernel):
|
def wrapper(kernel):
|
||||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
||||||
@@ -732,6 +774,23 @@ def autotune(configs, key, reset_to_zero=None):
|
|||||||
|
|
||||||
|
|
||||||
def heuristics(values):
|
def heuristics(values):
|
||||||
|
"""
|
||||||
|
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||||
|
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
|
||||||
|
|
||||||
|
.. highlight:: python
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||||
|
@triton.jit
|
||||||
|
def kernel(x_ptr, x_size, **META):
|
||||||
|
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||||
|
|
||||||
|
|
||||||
|
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||||
|
each such function takes a list of positional arguments as input.
|
||||||
|
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||||
|
"""
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
def wrapper(kernel):
|
def wrapper(kernel):
|
||||||
def fun(*args, **meta):
|
def fun(*args, **meta):
|
||||||
@@ -767,6 +826,8 @@ def jit(fn):
|
|||||||
return JITFunction(fn)
|
return JITFunction(fn)
|
||||||
|
|
||||||
|
|
||||||
|
######
|
||||||
|
|
||||||
def cdiv(x, y):
|
def cdiv(x, y):
|
||||||
return (x + y - 1) // y
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
@@ -49,16 +49,11 @@ for name in dir(frontend):
|
|||||||
def builtin(fn):
|
def builtin(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if 'builder' not in kwargs or \
|
if '_builder' not in kwargs or \
|
||||||
kwargs['builder'] is None:
|
kwargs['_builder'] is None:
|
||||||
raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?")
|
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
if wrapper.__doc__:
|
|
||||||
wrapper.__doc__ += """\
|
|
||||||
:param builder: IR builder to generate code into
|
|
||||||
:type builder: triton.ir.builder, optional from within JIT'ed functions
|
|
||||||
"""
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@@ -124,120 +119,120 @@ class block:
|
|||||||
self.dtype = block._init_dtype(self.handle.type.scalar)
|
self.dtype = block._init_dtype(self.handle.type.scalar)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __add__(self, other, builder=None):
|
def __add__(self, other, _builder=None):
|
||||||
return frontend.add(self, other, builder)
|
return frontend.add(self, other, _builder)
|
||||||
|
|
||||||
def __radd__(self, other, builder=None):
|
def __radd__(self, other, _builder=None):
|
||||||
return self.__add__(other, builder=builder)
|
return self.__add__(other, _builder=_builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __sub__(self, other, builder=None):
|
def __sub__(self, other, _builder=None):
|
||||||
return frontend.sub(self, other, builder)
|
return frontend.sub(self, other, _builder)
|
||||||
|
|
||||||
def __rsub__(self, other, builder=None):
|
def __rsub__(self, other, _builder=None):
|
||||||
return frontend.sub(other, self, builder)
|
return frontend.sub(other, self, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __mul__(self, other, builder=None):
|
def __mul__(self, other, _builder=None):
|
||||||
return frontend.mul(self, other, builder)
|
return frontend.mul(self, other, _builder)
|
||||||
|
|
||||||
def __rmul__(self, other, builder=None):
|
def __rmul__(self, other, _builder=None):
|
||||||
return self.__mul__(other, builder=builder)
|
return self.__mul__(other, _builder=_builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __truediv__(self, other, builder=None):
|
def __truediv__(self, other, _builder=None):
|
||||||
return frontend.truediv(self, other, builder)
|
return frontend.truediv(self, other, _builder)
|
||||||
|
|
||||||
def __rtruediv__(self, other, builder=None):
|
def __rtruediv__(self, other, _builder=None):
|
||||||
return frontend.truediv(other, self, builder)
|
return frontend.truediv(other, self, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __floordiv__(self, other, builder=None):
|
def __floordiv__(self, other, _builder=None):
|
||||||
return frontend.floordiv(self, other, builder)
|
return frontend.floordiv(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __mod__(self, other, builder=None):
|
def __mod__(self, other, _builder=None):
|
||||||
return frontend.mod(self, other, builder)
|
return frontend.mod(self, other, _builder)
|
||||||
|
|
||||||
# unary operators
|
# unary operators
|
||||||
@builtin
|
@builtin
|
||||||
def __neg__(self, builder=None):
|
def __neg__(self, _builder=None):
|
||||||
return frontend.minus(self, builder)
|
return frontend.minus(self, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __invert__(self, builder=None):
|
def __invert__(self, _builder=None):
|
||||||
return frontend.invert(self, builder)
|
return frontend.invert(self, _builder)
|
||||||
|
|
||||||
# bitwise operators
|
# bitwise operators
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __and__(self, other, builder=None):
|
def __and__(self, other, _builder=None):
|
||||||
return frontend.and_(self, other, builder)
|
return frontend.and_(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __or__(self, other, builder=None):
|
def __or__(self, other, _builder=None):
|
||||||
return frontend.or_(self, other, builder)
|
return frontend.or_(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __xor__(self, other, builder=None):
|
def __xor__(self, other, _builder=None):
|
||||||
return frontend.xor_(self, other, builder)
|
return frontend.xor_(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __lshift__(self, other, builder=None):
|
def __lshift__(self, other, _builder=None):
|
||||||
return frontend.shl(self, other, builder)
|
return frontend.shl(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __rshift__(self, other, builder=None):
|
def __rshift__(self, other, _builder=None):
|
||||||
return frontend.lshr(self, other, builder)
|
return frontend.lshr(self, other, _builder)
|
||||||
|
|
||||||
# comparison operators
|
# comparison operators
|
||||||
|
|
||||||
# >
|
# >
|
||||||
@builtin
|
@builtin
|
||||||
def __gt__(self, other, builder=None):
|
def __gt__(self, other, _builder=None):
|
||||||
return frontend.greater_than(self, other, builder)
|
return frontend.greater_than(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __rgt__(self, other, builder=None):
|
def __rgt__(self, other, _builder=None):
|
||||||
return frontend.greater_than(other, self, builder)
|
return frontend.greater_than(other, self, _builder)
|
||||||
|
|
||||||
# >=
|
# >=
|
||||||
@builtin
|
@builtin
|
||||||
def __ge__(self, other, builder=None):
|
def __ge__(self, other, _builder=None):
|
||||||
return frontend.greater_equal(self, other, builder)
|
return frontend.greater_equal(self, other, _builder)
|
||||||
|
|
||||||
def __rge__(self, other, builder=None):
|
def __rge__(self, other, _builder=None):
|
||||||
return frontend.greater_equal(other, self, builder)
|
return frontend.greater_equal(other, self, _builder)
|
||||||
|
|
||||||
# <
|
# <
|
||||||
@builtin
|
@builtin
|
||||||
def __lt__(self, other, builder=None):
|
def __lt__(self, other, _builder=None):
|
||||||
return frontend.less_than(self, other, builder)
|
return frontend.less_than(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __rlt__(self, other, builder=None):
|
def __rlt__(self, other, _builder=None):
|
||||||
return frontend.less_than(other, self, builder)
|
return frontend.less_than(other, self, _builder)
|
||||||
|
|
||||||
# <=
|
# <=
|
||||||
@builtin
|
@builtin
|
||||||
def __le__(self, other, builder=None):
|
def __le__(self, other, _builder=None):
|
||||||
return frontend.less_equal(self, other, builder)
|
return frontend.less_equal(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __rle__(self, other, builder=None):
|
def __rle__(self, other, _builder=None):
|
||||||
return frontend.less_equal(other, self, builder)
|
return frontend.less_equal(other, self, _builder)
|
||||||
|
|
||||||
# ==
|
# ==
|
||||||
@builtin
|
@builtin
|
||||||
def __eq__(self, other, builder=None):
|
def __eq__(self, other, _builder=None):
|
||||||
return frontend.equal(self, other, builder)
|
return frontend.equal(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __ne__(self, other, builder=None):
|
def __ne__(self, other, _builder=None):
|
||||||
return frontend.not_equal(self, other, builder)
|
return frontend.not_equal(self, other, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __getitem__(self, slices, builder=None):
|
def __getitem__(self, slices, _builder=None):
|
||||||
if isinstance(slices, slice):
|
if isinstance(slices, slice):
|
||||||
slices = [slices]
|
slices = [slices]
|
||||||
src_shape = self.shape
|
src_shape = self.shape
|
||||||
@@ -249,15 +244,15 @@ class block:
|
|||||||
elif sl == slice(None, None, None):
|
elif sl == slice(None, None, None):
|
||||||
dst_shape.append(src_shape[curr])
|
dst_shape.append(src_shape[curr])
|
||||||
curr += 1
|
curr += 1
|
||||||
ret = frontend.reshape(self, dst_shape, builder)
|
ret = frontend.reshape(self, dst_shape, _builder)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def to(self, dtype, bitcast=False, builder=None):
|
def to(self, dtype, bitcast=False, _builder=None):
|
||||||
dtype = dtype.handle(builder)
|
dtype = dtype.handle(_builder)
|
||||||
if bitcast:
|
if bitcast:
|
||||||
return frontend.bitcast(self, dtype, builder)
|
return frontend.bitcast(self, dtype, _builder)
|
||||||
return frontend.cast(self, dtype, builder)
|
return frontend.cast(self, dtype, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -266,25 +261,25 @@ class block:
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def program_id(axis, builder=None):
|
def program_id(axis, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the id of the current program instance along the given :code:`axis`.
|
Returns the id of the current program instance along the given :code:`axis`.
|
||||||
|
|
||||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||||
:type axis: int
|
:type axis: int
|
||||||
"""
|
"""
|
||||||
return frontend.program_id(axis, builder)
|
return frontend.program_id(axis, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def num_programs(axis, builder=None):
|
def num_programs(axis, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the number of program instances launched along the given :code:`axis`.
|
Returns the number of program instances launched along the given :code:`axis`.
|
||||||
|
|
||||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||||
:type axis: int
|
:type axis: int
|
||||||
"""
|
"""
|
||||||
return frontend.num_programs(axis, builder)
|
return frontend.num_programs(axis, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -293,7 +288,7 @@ def num_programs(axis, builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def arange(start, end, builder=None):
|
def arange(start, end, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns contiguous values within the half-open interval [:code:`start`, :code:`end`).
|
Returns contiguous values within the half-open interval [:code:`start`, :code:`end`).
|
||||||
|
|
||||||
@@ -302,11 +297,11 @@ def arange(start, end, builder=None):
|
|||||||
:param end: End of the interval. Must be a power of two >= start.
|
:param end: End of the interval. Must be a power of two >= start.
|
||||||
:type end: int
|
:type end: int
|
||||||
"""
|
"""
|
||||||
return frontend.arange(start, end, builder)
|
return frontend.arange(start, end, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def zeros(shape, dtype, builder=None):
|
def zeros(shape, dtype, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
||||||
|
|
||||||
@@ -316,7 +311,7 @@ def zeros(shape, dtype, builder=None):
|
|||||||
:type dtype: DType
|
:type dtype: DType
|
||||||
"""
|
"""
|
||||||
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
|
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
|
||||||
return frontend.zeros(shape, dtype, builder)
|
return frontend.zeros(shape, dtype, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -325,7 +320,7 @@ def zeros(shape, dtype, builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def broadcast(input, other, builder=None):
|
def broadcast(input, other, _builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to broadcast the two given blocks to a common compatible shape.
|
Tries to broadcast the two given blocks to a common compatible shape.
|
||||||
|
|
||||||
@@ -334,11 +329,11 @@ def broadcast(input, other, builder=None):
|
|||||||
:param other: The second input block.
|
:param other: The second input block.
|
||||||
:type other: Block
|
:type other: Block
|
||||||
"""
|
"""
|
||||||
return frontend.broadcast(input, other, builder)
|
return frontend.broadcast(input, other, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def broadcast_to(input, shape, builder=None):
|
def broadcast_to(input, shape, _builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to broadcast the given block to a new :code:`shape`.
|
Tries to broadcast the given block to a new :code:`shape`.
|
||||||
|
|
||||||
@@ -347,11 +342,11 @@ def broadcast_to(input, shape, builder=None):
|
|||||||
:param shape: The desired shape.
|
:param shape: The desired shape.
|
||||||
:type shape: Tuple[int]
|
:type shape: Tuple[int]
|
||||||
"""
|
"""
|
||||||
return frontend.broadcast_to(input, shape, builder)
|
return frontend.broadcast_to(input, shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def reshape(input, shape, builder=None):
|
def reshape(input, shape, _builder=None):
|
||||||
"""
|
"""
|
||||||
Tries to reshape the given block to a new shape.
|
Tries to reshape the given block to a new shape.
|
||||||
|
|
||||||
@@ -361,7 +356,7 @@ def reshape(input, shape, builder=None):
|
|||||||
:type shape: Tuple[int]
|
:type shape: Tuple[int]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return frontend.reshape(input, shape, builder)
|
return frontend.reshape(input, shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -370,7 +365,7 @@ def reshape(input, shape, builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def dot(input, other, builder=None):
|
def dot(input, other, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the matrix product of two blocks.
|
Returns the matrix product of two blocks.
|
||||||
|
|
||||||
@@ -381,16 +376,16 @@ def dot(input, other, builder=None):
|
|||||||
:param other: The second block to be multiplied.
|
:param other: The second block to be multiplied.
|
||||||
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
||||||
"""
|
"""
|
||||||
return frontend.dot(input, other, builder)
|
return frontend.dot(input, other, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Memory Operations
|
# Non-Atomic Memory Operations
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def load(pointer, mask=None, other=None, builder=None):
|
def load(pointer, mask=None, other=None, _builder=None):
|
||||||
"""
|
"""
|
||||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||||
|
|
||||||
@@ -405,11 +400,11 @@ def load(pointer, mask=None, other=None, builder=None):
|
|||||||
:param other: if mask[idx] is false, return other[idx]
|
:param other: if mask[idx] is false, return other[idx]
|
||||||
:type other: Block, optional
|
:type other: Block, optional
|
||||||
"""
|
"""
|
||||||
return frontend.load(pointer, mask, other, builder)
|
return frontend.load(pointer, mask, other, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def store(pointer, value, mask=None, builder=None):
|
def store(pointer, value, mask=None, _builder=None):
|
||||||
"""
|
"""
|
||||||
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||||
|
|
||||||
@@ -422,13 +417,20 @@ def store(pointer, value, mask=None, builder=None):
|
|||||||
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
|
:param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
|
||||||
:type mask: Block of triton.int1, optional
|
:type mask: Block of triton.int1, optional
|
||||||
"""
|
"""
|
||||||
return frontend.store(pointer, value, mask, builder)
|
return frontend.store(pointer, value, mask, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
# -----------------------
|
||||||
def atomic_cas(pointer, cmp, val, builder=None):
|
# Atomic Memory Operations
|
||||||
"""
|
# -----------------------
|
||||||
Performs an atomic "compare-and-swap" and the memory locations specified by :code:`pointer`.
|
|
||||||
|
def _add_atomic_docstr(name):
|
||||||
|
|
||||||
|
def _decorator(func):
|
||||||
|
docstr = """
|
||||||
|
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||||
|
|
||||||
|
Return the data stored at :code:`pointer` before the atomic operation.
|
||||||
|
|
||||||
:param pointer: The memory locations to compare-and-swap.
|
:param pointer: The memory locations to compare-and-swap.
|
||||||
:type pointer: Block of dtype=triton.PointerDType
|
:type pointer: Block of dtype=triton.PointerDType
|
||||||
@@ -437,62 +439,56 @@ def atomic_cas(pointer, cmp, val, builder=None):
|
|||||||
:param val: The values to copy in case the expected value matches the contained value.
|
:param val: The values to copy in case the expected value matches the contained value.
|
||||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||||
"""
|
"""
|
||||||
|
func.__doc__ = docstr.format(name=name)
|
||||||
|
return func
|
||||||
|
|
||||||
return frontend.atomic_cas(pointer, cmp, val, builder)
|
return _decorator
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
@_add_atomic_docstr("compare-and-swap")
|
||||||
|
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||||
|
return frontend.atomic_cas(pointer, cmp, val, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_xchg(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("exchange")
|
||||||
"""
|
def atomic_xchg(pointer, val, mask=None, _builder=None):
|
||||||
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values.
|
return frontend.atomic_xchg(pointer, val, mask, _builder)
|
||||||
|
|
||||||
:param pointer: The memory locations which contain the old values
|
@builtin
|
||||||
:type pointer: Block of dtype=triton.PointerDType
|
@_add_atomic_docstr("add")
|
||||||
:param val: The new values to store
|
def atomic_add(pointer, val, mask=None, _builder=None):
|
||||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
return frontend.atomic_add(pointer, val, mask, _builder)
|
||||||
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
|
|
||||||
:type mask: Block of triton.int1, optional
|
|
||||||
"""
|
|
||||||
return frontend.atomic_xchg(pointer, val, mask, builder)
|
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_add(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("max")
|
||||||
"""
|
def atomic_max(pointer, val, mask=None, _builder=None):
|
||||||
Performs an atomic add and the memory locations specified by :code:`pointer`.
|
return frontend.atomic_max(pointer, val, mask, _builder)
|
||||||
:param pointer: The memory locations which contain the old values
|
|
||||||
:type pointer: Block of dtype=triton.PointerDType
|
|
||||||
:param val: The values to add
|
|
||||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
|
||||||
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
|
|
||||||
:type mask: Block of triton.int1, optional
|
|
||||||
"""
|
|
||||||
return frontend.atomic_add(pointer, val, mask, builder)
|
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_max(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("min")
|
||||||
return frontend.atomic_max(pointer, val, mask, builder)
|
def atomic_min(pointer, val, mask=None, _builder=None):
|
||||||
|
return frontend.atomic_min(pointer, val, mask, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_min(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("logical and")
|
||||||
return frontend.atomic_min(pointer, val, mask, builder)
|
def atomic_and(pointer, val, mask=None, _builder=None):
|
||||||
|
return frontend.atomic_and(pointer, val, mask, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_and(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("logical or")
|
||||||
return frontend.atomic_and(pointer, val, mask, builder)
|
def atomic_or(pointer, val, mask=None, _builder=None):
|
||||||
|
return frontend.atomic_or(pointer, val, mask, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def atomic_or(pointer, val, mask=None, builder=None):
|
@_add_atomic_docstr("logical xor")
|
||||||
return frontend.atomic_or(pointer, val, mask, builder)
|
def atomic_xor(pointer, val, mask=None, _builder=None):
|
||||||
|
return frontend.atomic_xor(pointer, val, mask, _builder)
|
||||||
|
|
||||||
@builtin
|
|
||||||
def atomic_xor(pointer, val, mask=None, builder=None):
|
|
||||||
return frontend.atomic_xor(pointer, val, mask, builder)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -501,7 +497,7 @@ def atomic_xor(pointer, val, mask=None, builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def where(condition, x, y, builder=None):
|
def where(condition, x, y, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
|
Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
|
||||||
|
|
||||||
@@ -517,108 +513,89 @@ def where(condition, x, y, builder=None):
|
|||||||
:param x: values selected at indices where condition is True.
|
:param x: values selected at indices where condition is True.
|
||||||
:param y: values selected at indices where condition is False.
|
:param y: values selected at indices where condition is False.
|
||||||
"""
|
"""
|
||||||
return frontend.where(condition, x, y, builder)
|
return frontend.where(condition, x, y, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Math
|
# Math
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
def _add_math_1arg_docstr(name):
|
||||||
|
|
||||||
@builtin
|
def _decorator(func):
|
||||||
def exp(x, builder=None):
|
docstr = """
|
||||||
"""
|
Computes the element-wise {name} of :code:`x`
|
||||||
Computes the element-wise exponential of :code:`x`
|
|
||||||
|
|
||||||
:param x: the input values
|
:param x: the input values
|
||||||
:type x: Block
|
:type x: Block
|
||||||
"""
|
"""
|
||||||
|
func.__doc__ = docstr.format(name=name)
|
||||||
|
return func
|
||||||
|
|
||||||
return frontend.exp(x, builder)
|
return _decorator
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
@_add_math_1arg_docstr("exponential")
|
||||||
|
def exp(x, _builder=None):
|
||||||
|
return frontend.exp(x, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def log(x, builder=None):
|
@_add_math_1arg_docstr("natural logarithm")
|
||||||
"""
|
def log(x, _builder=None):
|
||||||
Computes the element-wise natural logarithm of :code:`x`
|
return frontend.log(x, _builder)
|
||||||
|
|
||||||
:param x: the input values
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
|
|
||||||
return frontend.log(x, builder)
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def cos(x, builder=None):
|
@_add_math_1arg_docstr("cosine")
|
||||||
"""
|
def cos(x, _builder=None):
|
||||||
Computes the element-wise cosine of :code:`x`
|
return frontend.cos(x, _builder)
|
||||||
|
|
||||||
:param x: the input values
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
|
|
||||||
return frontend.cos(x, builder)
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def sin(x, builder=None):
|
@_add_math_1arg_docstr("sine")
|
||||||
"""
|
def sin(x, _builder=None):
|
||||||
Computes the element-wise sine of :code:`x`
|
return frontend.sin(x, _builder)
|
||||||
|
|
||||||
:param x: the input values
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
|
|
||||||
return frontend.sin(x, builder)
|
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def sqrt(x, builder=None):
|
@_add_math_1arg_docstr("square root")
|
||||||
"""
|
def sqrt(x, _builder=None):
|
||||||
Computes the element-wise square root of :code:`x`
|
return frontend.sqrt(x, _builder)
|
||||||
|
|
||||||
:param x: the input values
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
return frontend.sqrt(x, builder)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Reductions
|
# Reductions
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
def _add_reduction_docstr(name):
|
||||||
|
|
||||||
@builtin
|
def _decorator(func):
|
||||||
def max(input, axis, builder=None):
|
docstr = """
|
||||||
"""
|
Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis`
|
||||||
Returns the maximum value of all elements in the :code:`input` block along the provided :code:`axis`
|
|
||||||
|
|
||||||
:param input: the input values
|
:param input: the input values
|
||||||
:param axis: the dimension along which the reduction should be done
|
:param axis: the dimension along which the reduction should be done
|
||||||
"""
|
"""
|
||||||
return frontend.max(input, axis, builder)
|
func.__doc__ = docstr.format(name=name)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
@_add_reduction_docstr("maximum")
|
||||||
|
def max(input, axis, _builder=None):
|
||||||
|
return frontend.max(input, axis, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def min(input, axis, builder=None):
|
@_add_reduction_docstr("minimum")
|
||||||
"""
|
def min(input, axis, _builder=None):
|
||||||
Returns the minimum value of all elements in the :code:`input` block along the provided :code:`axis`
|
return frontend.min(input, axis, _builder)
|
||||||
|
|
||||||
:param input: the input values
|
|
||||||
:param axis: the dimension along which the reduction should be done
|
|
||||||
"""
|
|
||||||
return frontend.min(input, axis, builder)
|
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def sum(input, axis, builder=None):
|
@_add_reduction_docstr("sum")
|
||||||
"""
|
def sum(input, axis, _builder=None):
|
||||||
Returns the sum of all elements in the :code:`input` block along the provided :code:`axis`
|
return frontend.sum(input, axis, _builder)
|
||||||
|
|
||||||
:param input: the input values
|
|
||||||
:param axis: the dimension along which the reduction should be done
|
|
||||||
"""
|
|
||||||
|
|
||||||
return frontend.sum(input, axis, builder)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -627,24 +604,24 @@ def sum(input, axis, builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def debug_barrier(builder=None):
|
def debug_barrier(_builder=None):
|
||||||
return frontend.debug_barrier(builder)
|
return frontend.debug_barrier(_builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def multiple_of(input, value, builder=None):
|
def multiple_of(input, value, _builder=None):
|
||||||
"""
|
"""
|
||||||
Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
|
Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
|
||||||
"""
|
"""
|
||||||
return frontend.multiple_of(input, value, builder)
|
return frontend.multiple_of(input, value, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def max_contiguous(input, value, builder=None):
|
def max_contiguous(input, value, _builder=None):
|
||||||
"""
|
"""
|
||||||
Let the compiler know that the first `value` values in :code:`input` are contiguous.
|
Let the compiler know that the first `value` values in :code:`input` are contiguous.
|
||||||
"""
|
"""
|
||||||
return frontend.max_contiguous(input, value, builder)
|
return frontend.max_contiguous(input, value, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
@@ -690,24 +667,14 @@ def maximum(x, y):
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
@_add_math_1arg_docstr("sigmoid")
|
||||||
def sigmoid(x):
|
def sigmoid(x):
|
||||||
"""
|
|
||||||
Computes the element-wise sigmoid of :code:`x`.
|
|
||||||
|
|
||||||
:param x: the input block
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
return 1 / (1 + triton.language.exp(-x))
|
return 1 / (1 + triton.language.exp(-x))
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
@_add_math_1arg_docstr("softmax")
|
||||||
def softmax(x):
|
def softmax(x):
|
||||||
"""
|
|
||||||
Computes the element-wise softmax of :code:`x`.
|
|
||||||
|
|
||||||
:param x: the input block
|
|
||||||
:type x: Block
|
|
||||||
"""
|
|
||||||
z = x - triton.language.max(x, 0)
|
z = x - triton.language.max(x, 0)
|
||||||
num = triton.language.exp(z)
|
num = triton.language.exp(z)
|
||||||
den = triton.language.sum(num, 0)
|
den = triton.language.sum(num, 0)
|
||||||
|
@@ -322,7 +322,7 @@ else:
|
|||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||||
x_vals=[
|
x_vals=[
|
||||||
128 * i for i in range(1, 33)
|
128 * i for i in range(2, 33)
|
||||||
], # different possible values for `x_name`
|
], # different possible values for `x_name`
|
||||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||||
# possible values for `line_arg`
|
# possible values for `line_arg`
|
||||||
|
Reference in New Issue
Block a user