[FRONTEND] Complete rewrite of the runtime (#644)

This PR completely rewrites the Triton runtime to make it leaner and to cleanly
separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
Author: Philippe Tillet
Date: 2022-09-18 08:51:48 -07:00
Committed by: GitHub
Parent: 889d9e34a1
Commit: 4a77dfb042
17 changed files with 1198 additions and 780 deletions
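For downstream code, the public entry points move out of triton.code_gen and into triton.runtime and triton.compiler. A minimal sketch of the new import layout, taken from the test and __init__.py changes below:

    # Old layout (removed by this PR):
    #   from triton.code_gen import JITFunction, TensorWrapper, reinterpret
    #   triton.code_gen.CompilationError
    # New layout introduced here:
    from triton.runtime import Config, autotune, heuristics, KernelInterface
    from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
    from triton.compiler import compile, CompilationError  # also re-exported as triton.compile / triton.CompilationError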

View File

@@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) {
assert(backend == ROCM);
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
}, py::return_value_policy::take_ownership);
struct InstanceDescriptor
{
std::unordered_set<int> divisibleBy16;
std::unordered_set<int> equalTo1;
};
py::class_<InstanceDescriptor>(m, "instance_descriptor")
.def(py::init<>())
.def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
.def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
.def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
}
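The binding above exposes argument-specialization metadata to Python. A hedged sketch of constructing it from the Python side; the code_gen submodule path follows the commented-out call in runtime/jit.py further down and is an assumption (the new launcher currently builds an equivalent namedtuple instead):

    import triton._C.libtriton.triton as _triton

    # arguments 0 and 2 are divisible by 16, argument 1 is the literal 1
    desc = _triton.code_gen.instance_descriptor({0, 2}, {1})
    print(desc.divisible_by_16, desc.equal_to_1)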
@@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) {
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
py::class_<ir::module>(m, "module")
py::class_<ir::module>(m, "module", py::dynamic_attr())
.def(py::init<std::string, ir::builder &>())
.def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference)
.def("get_functions", &ir::module::get_function_list, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("print", [](ir::module *self) {
self->print(std::cout);

View File

@@ -11,7 +11,7 @@ from numpy.random import RandomState
import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
with pytest.raises(triton.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
with pytest.raises(triton.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
@@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
def catch_compilation_error(kernel):
try:
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
except triton.code_gen.CompilationError as e:
except triton.CompilationError as e:
np.testing.assert_(True)
except BaseException:
np.testing.assert_(False)
@@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache):
assert 'ld.global.cg' not in ptx
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
def test_vectorization(N):
src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
@@ -1221,10 +1221,8 @@ def test_vectorization(N):
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 4 == 0:
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
elif N % 2 == 0:
assert "ld.global.v2.b32" in ptx
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
def cache_hook(*args, **kwargs):
nonlocal spec_type
spec_type = kwargs["compile"]["arg_types"][0][1]
spec_type = kwargs["compile"]["signature"][0]
JITFunction.cache_hook = cache_hook
@triton.jit
@@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
x = torch.tensor([3.14159], device='cuda')
if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
with pytest.raises(OverflowError):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)

View File

@@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
kernel.configs = configs
# kernel.run = kernel.run.run.run
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N

View File

@@ -7,7 +7,7 @@ import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
from triton.runtime.jit import JITFunction
tmpdir = ".tmp"
@@ -99,16 +99,16 @@ def test_specialize(mode):
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 5, 'disable': 1}[mode]
target = {'enable': 3, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs['key'].split('-')
triton.code_gen.JITFunction.cache_hook = get_cache_str
cache_str = kwargs["repr"]
triton.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.code_gen.JITFunction.cache_hook = None
triton.JITFunction.cache_hook = None
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

View File

@@ -6,9 +6,10 @@ __version__ = '2.0.0'
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
JITFunction, Config, Autotuner, reinterpret
from .utils import *
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import code_gen
from . import testing
from . import ops

File diff suppressed because it is too large.

View File

@@ -420,7 +420,7 @@ class tensor:
self.numel = 1
for s in self.shape:
self.numel *= s
is_pow2 = (self.numel and (not(self.numel & (self.numel - 1))))
is_pow2 = (self.numel and (not (self.numel & (self.numel - 1))))
if not is_pow2:
raise ValueError("Triton tensors must have a power-of-two number of elements")
self.numel = constexpr(self.numel)

View File

@@ -18,8 +18,8 @@ def num_warps(n):
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
Out, A, LUT, R, stride_xz,
extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
@@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function):
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
out, a, lut, rel_logits, a.stride(0),
rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,

View File

@@ -26,9 +26,6 @@ def get_configs_io_bound():
return configs
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,

View File

@@ -0,0 +1,2 @@
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
from .jit import JITFunction, KernelInterface, version_key # noqa: F401

View File

@@ -0,0 +1,204 @@
from __future__ import annotations
import builtins
import time
from typing import Dict
from ..testing import do_bench
from .jit import KernelInterface
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict the running time of different configs; returns the running time
'top_k': number of configs to bench
'early_config_prune' (optional): a function used to prune configs early (e.g., based on num_stages). It takes configs: List[Config] as its input and returns the pruned configs.
'''
if not configs:
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return do_bench(kernel_call)
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
:ivar pre_hook: a function that will be called before the kernel is launched. It is passed
a dict of the kernel's arguments.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensors to zero before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict the running time of different configs; returns the running time
'top_k': number of configs to bench
'early_config_prune' (optional): a function used to prune configs early (e.g., based on num_stages). It takes configs: List[Config] as its input and returns the pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
return decorator
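Wiring up the prune_configs_by hook described above might look like the sketch below; the kernel, the perf model, and the pruning heuristic are hypothetical stand-ins, not part of this PR:

    import triton
    import triton.language as tl

    def my_perf_model(M, N, K, num_stages, num_warps, **kwargs):
        # crude proxy for the running time: work divided by parallelism
        return (M * N * K) / (num_warps * num_stages)

    def my_early_prune(configs, nargs):
        # e.g. drop deeply pipelined configs when K is small
        return [c for c in configs if c.num_stages <= 3 or nargs['K'] > 1024]

    @triton.autotune(
        configs=[triton.Config({'BLOCK': 64}, num_warps=4),
                 triton.Config({'BLOCK': 128}, num_warps=8, num_stages=4)],
        key=['M', 'N', 'K'],
        prune_configs_by={'perf_model': my_perf_model, 'top_k': 1,
                          'early_config_prune': my_early_prune},
    )
    @triton.jit
    def my_kernel(A, B, C, M, N, K, BLOCK: tl.constexpr):
        pass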
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values
self.arg_names = arg_names
def run(self, *args, **kwargs):
for v, heur in self.values.items():
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
return self.fn.run(*args, **kwargs)
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
Each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
return Heuristics(fn, fn.arg_names, values)
return decorator

View File

@@ -0,0 +1,415 @@
from __future__ import annotations, division
import ast
import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import namedtuple
import torch
import triton
try:
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
be used to invalidate a JITFunction's hash when its source code -- or
that of its dependencies -- changes.
"""
def __init__(self, globals, src) -> None:
super().__init__()
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
self.globals = globals
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
return None
return getattr(lhs, node.attr)
def visit_Call(self, node):
func = self.visit(node.func)
if func is None:
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, JITFunction)
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
with open(triton.compiler.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface:
def __getitem__(self, grid):
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
def launcher(*args, **kwargs):
return self.run(*args, grid=grid, **kwargs)
return launcher
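A minimal sketch of launching through this proxy with the new runtime (the copy kernel is illustrative, not part of this PR). Both a fixed grid tuple and a callable grid work, since the launcher generated below evaluates callable grids against a dict of the kernel's named arguments:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy(dst, src, n, BLOCK: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n
        tl.store(dst + offsets, tl.load(src + offsets, mask=mask), mask=mask)

    src = torch.randn(4096, device='cuda')
    dst = torch.empty_like(src)
    _copy[(16,)](dst, src, src.numel(), BLOCK=256)                       # fixed grid tuple
    _copy[lambda args: (triton.cdiv(args['n'], args['BLOCK']),)](        # callable grid
        dst, src, src.numel(), BLOCK=256)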
class JITFunction(KernelInterface):
cache_hook = None
divisibility = 16
@staticmethod
def _key_of(arg):
if hasattr(arg, "dtype"):
return arg.dtype
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**31 <= arg and arg <= 2**32 - 1:
return "u32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
elif isinstance(x, int):
return x % JITFunction.divisibility == 0
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
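A worked illustration of the descriptor produced above; the probe kernel and values are hypothetical, and _get_config is a private helper exercised here only to show the output shape:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _probe(ptr, flag, n, BLOCK: tl.constexpr):
        pass

    x = torch.empty(1024, device='cuda')           # data_ptr() is 16-byte aligned in practice
    desc = _probe._get_config(x, 1, 4096, 64)
    assert set(desc.divisible_by_16) == {0, 2, 3}  # aligned pointer, 4096 and 64
    assert set(desc.equal_to_1) == {1}             # the literal 1 at index 1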
@staticmethod
def _type_of(key):
if isinstance(key, (torch.dtype, triton.language.dtype)):
ty = {
torch.bool: 'i1',
torch.float16: 'fp16',
torch.bfloat16: 'bf16',
torch.float32: 'fp32',
torch.float64: 'fp64',
torch.uint8: 'u8',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
}[key]
return f'*{ty}'
if key is None:
return '*i8'
assert isinstance(key, str)
return key
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
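The keyword arguments passed to the hook can be seen in the updated tests above (key, repr, compile, ...); a small sketch of registering a hook that just logs compilations:

    from triton.runtime.jit import JITFunction

    def log_compile(*args, **kwargs):
        # kwargs carries "key", "repr", "fn", "compile" (signature, constants, ...),
        # "is_manual_warmup" and "already_compiled"
        print("compiling:", kwargs["repr"], "signature:", kwargs["compile"]["signature"])
        return False  # a truthy return value makes the generated launcher skip compilation

    JITFunction.cache_hook = log_compile
    # ... launch kernels here ...
    JITFunction.cache_hook = None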
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
specializations = []
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
f'else (False,)']
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
key = (version_key, sig_key, constexpr_key, spec_key)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
if callable(grid):
grid = grid({{{grid_args}}})
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
device = torch.cuda.current_device()
torch.cuda.set_device(device)
if stream is None:
stream = get_cuda_stream(device)
try:
bin = cache[key]
bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
return bin
# kernel not cached -- compile
except KeyError:
# build dict of constant values
args = [{args}]
configs = self._get_config(*args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
device = 0
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
self.cache[key] = bin
return bin
return None
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton, "torch": torch}
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None):
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
# cache of just-in-time compiled kernels
self.cache = dict()
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# index of constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
@functools.lru_cache()
def cache_key(self):
# TODO : hash should be attribute of `self`
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
return self.hash
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
self.hash = None
JITFunction.cache_key.fget.cache_clear()
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
def jit(*args, **kwargs):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
if args:
assert len(args) == 1
assert callable(args[0])
return JITFunction(args[0], **kwargs)
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
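reinterpret lets callers view a tensor's storage under a different Triton dtype, as the updated tests do; a small sketch (assumes a CUDA device):

    import torch
    import triton.language as tl
    from triton.runtime.jit import reinterpret

    x = torch.arange(16, dtype=torch.int32, device='cuda')
    x_u32 = reinterpret(x, tl.uint32)            # same storage, viewed as uint32 by Triton
    assert x_u32.data_ptr() == x.data_ptr()
    assert reinterpret(x_u32, torch.int32) is x  # reinterpreting back to the base dtype returns the base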

View File

@@ -7,7 +7,7 @@ from contextlib import contextmanager
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
from .compiler import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass

python/triton/utils.py (new file, 48 lines)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import torch
def cdiv(x, y):
return (x + y - 1) // y
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
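A few worked values for the two helpers above:

    from triton.utils import cdiv, next_power_of_2

    assert cdiv(10, 3) == 4                # ceiling division
    assert cdiv(12, 4) == 3                # exact division is unchanged
    assert next_power_of_2(1) == 1
    assert next_power_of_2(1000) == 1024
    assert next_power_of_2(1024) == 1024   # powers of two map to themselves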
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')