[DOCS] Various improvements (#224)
- Added docstr for autotune, Config, heuristics - Added docstr for atomics - Hiding internal _builder argument used for built-in language primitives - Re-factor docstr to use common templates between similar functions.
This commit is contained in:
@@ -193,11 +193,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
kws = dict()
|
||||
|
||||
if self.is_triton_object(lhs):
|
||||
kws['builder'] = self.builder
|
||||
kws['_builder'] = self.builder
|
||||
ret = getattr(lhs, fn)(rhs, **kws)
|
||||
if ret is NotImplemented:
|
||||
if self.is_triton_object(rhs):
|
||||
kws['builder'] = self.builder
|
||||
kws['_builder'] = self.builder
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
ret = getattr(rhs, fn)(lhs, **kws)
|
||||
return ret
|
||||
@@ -260,10 +260,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.IsNot: '__ne__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_object(lhs):
|
||||
return getattr(lhs, fn)(rhs, builder=self.builder)
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, builder=self.builder)
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
|
||||
@@ -275,7 +275,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_object(op):
|
||||
return getattr(op, fn)(builder=self.builder)
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
@@ -308,7 +308,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_object(lhs):
|
||||
return lhs.__getitem__(slices, builder=self.builder)
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
def visit_ExtSlice(self, node):
|
||||
@@ -331,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||
self.visit(pos_cond_node),\
|
||||
self.visit(neg_cond_node),\
|
||||
builder=self.builder)
|
||||
_builder=self.builder)
|
||||
#cond_node = neg_cond_node
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
@@ -385,7 +385,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return fn(*args, generator=self, **kws)
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language:
|
||||
return fn(*args, builder=self.builder, **kws)
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Num(self, node):
|
||||
@@ -714,6 +714,19 @@ class JITFunction:
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
`num_warps=8`, then each kernel instance will be automatically parallelized to
|
||||
cooperatively execute using `8 * 32 = 256` threads.
|
||||
:type num_warps: int
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
"""
|
||||
def __init__(self, meta, num_warps=4, num_stages=2):
|
||||
self.meta = meta
|
||||
self.num_warps = num_warps
|
||||
@@ -721,6 +734,35 @@ class Config:
|
||||
|
||||
|
||||
def autotune(configs, key, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
resets the value of the provided tensor to `zero` before running any configuration.
|
||||
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
||||
@@ -732,6 +774,23 @@ def autotune(configs, key, reset_to_zero=None):
|
||||
|
||||
|
||||
def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
:type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
def fun(*args, **meta):
|
||||
@@ -767,6 +826,8 @@ def jit(fn):
|
||||
return JITFunction(fn)
|
||||
|
||||
|
||||
######
|
||||
|
||||
def cdiv(x, y):
    """
    Return ``(x + y - 1) // y``.

    For integers with ``y > 0`` this is the ceiling of ``x / y``, i.e. the
    number of size-``y`` chunks needed to cover ``x`` items.

    :param x: the dividend
    :param y: the divisor
    """
    # Bias the dividend by (y - 1) so that floor division rounds up.
    biased = x + y - 1
    return biased // y
|
||||
|
||||
|
Reference in New Issue
Block a user