[FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and clearly separate the compilation step from the just-in-time caching logic. This should substantially reduce launch overhead.
This commit is contained in:
@@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) {
|
||||
assert(backend == ROCM);
|
||||
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
}, py::return_value_policy::take_ownership);
|
||||
|
||||
|
||||
struct InstanceDescriptor
|
||||
{
|
||||
std::unordered_set<int> divisibleBy16;
|
||||
std::unordered_set<int> equalTo1;
|
||||
};
|
||||
|
||||
py::class_<InstanceDescriptor>(m, "instance_descriptor")
|
||||
.def(py::init<>())
|
||||
.def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
|
||||
.def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
|
||||
.def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
|
||||
}
|
||||
|
||||
|
||||
@@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get", &ir::struct_type::get, ret::reference)
|
||||
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||
|
||||
py::class_<ir::module>(m, "module")
|
||||
py::class_<ir::module>(m, "module", py::dynamic_attr())
|
||||
.def(py::init<std::string, ir::builder &>())
|
||||
.def("has_function", &ir::module::has_function)
|
||||
.def("get_function", &ir::module::get_function, ret::reference)
|
||||
.def("get_functions", &ir::module::get_function_list, ret::reference)
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||
.def("print", [](ir::module *self) {
|
||||
self->print(std::cout);
|
||||
|
@@ -11,7 +11,7 @@ from numpy.random import RandomState
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
|
||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
elif (op in ('%', '/') and
|
||||
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
|
||||
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
|
||||
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
|
||||
with pytest.raises(triton.CompilationError) as exc_info:
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
|
||||
else:
|
||||
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
else:
|
||||
numpy_expr = None
|
||||
if 'float' in dtype_x + dtype_y:
|
||||
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
|
||||
with pytest.raises(triton.CompilationError) as exc_info:
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
|
||||
# The CompilationError must have been caused by a C++ exception with this text.
|
||||
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
|
||||
@@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
def catch_compilation_error(kernel):
|
||||
try:
|
||||
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
except triton.code_gen.CompilationError as e:
|
||||
except triton.CompilationError as e:
|
||||
np.testing.assert_(True)
|
||||
except BaseException:
|
||||
np.testing.assert_(False)
|
||||
@@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache):
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
|
||||
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
||||
def test_vectorization(N):
|
||||
src = torch.empty(1024, device='cuda')
|
||||
dst = torch.empty(1024, device='cuda')
|
||||
@@ -1221,10 +1221,8 @@ def test_vectorization(N):
|
||||
tl.store(dst + offsets, x, mask=offsets < N)
|
||||
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||
ptx = pgm.asm["ptx"]
|
||||
if N % 4 == 0:
|
||||
if N % 16 == 0:
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
elif N % 2 == 0:
|
||||
assert "ld.global.v2.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
def cache_hook(*args, **kwargs):
|
||||
nonlocal spec_type
|
||||
spec_type = kwargs["compile"]["arg_types"][0][1]
|
||||
spec_type = kwargs["compile"]["signature"][0]
|
||||
JITFunction.cache_hook = cache_hook
|
||||
|
||||
@triton.jit
|
||||
@@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
|
||||
if overflow:
|
||||
with pytest.raises(RuntimeError, match='integer overflow'):
|
||||
with pytest.raises(OverflowError):
|
||||
kernel[(1, )](value, x)
|
||||
else:
|
||||
kernel[(1, )](value, x)
|
||||
|
@@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
|
||||
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
|
||||
kernel = triton.ops._matmul.kernel
|
||||
decorators = kernel.kernel_decorators
|
||||
kernel.kernel_decorators = []
|
||||
triton.autotune(configs, [])(kernel)
|
||||
kernel.kernel_decorators += decorators[1:]
|
||||
kernel.configs = configs
|
||||
# kernel.run = kernel.run.run.run
|
||||
|
||||
# get matrix shape
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
|
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction
|
||||
from triton.runtime.jit import JITFunction
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
@@ -99,16 +99,16 @@ def test_specialize(mode):
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
|
||||
target = {'enable': 5, 'disable': 1}[mode]
|
||||
target = {'enable': 3, 'disable': 1}[mode]
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
|
||||
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
|
||||
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
|
||||
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
def get_cache_str(*args, **kwargs):
|
||||
nonlocal cache_str
|
||||
cache_str = kwargs['key'].split('-')
|
||||
triton.code_gen.JITFunction.cache_hook = get_cache_str
|
||||
cache_str = kwargs["repr"]
|
||||
triton.JITFunction.cache_hook = get_cache_str
|
||||
reset_tmp_dir()
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
kernel[(1, )](value, x)
|
||||
triton.code_gen.JITFunction.cache_hook = None
|
||||
triton.JITFunction.cache_hook = None
|
||||
|
||||
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
|
||||
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
|
||||
spec_type = None if cache_str_match is None else cache_str_match.group(1)
|
||||
assert spec_type == value_type
|
||||
|
||||
|
@@ -6,9 +6,10 @@ __version__ = '2.0.0'
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
# submodules
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
|
||||
JITFunction, Config, Autotuner, reinterpret
|
||||
from .utils import *
|
||||
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from . import language
|
||||
from . import code_gen
|
||||
from . import testing
|
||||
from . import ops
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -420,7 +420,7 @@ class tensor:
|
||||
self.numel = 1
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
is_pow2 = (self.numel and (not(self.numel & (self.numel - 1))))
|
||||
is_pow2 = (self.numel and (not (self.numel & (self.numel - 1))))
|
||||
if not is_pow2:
|
||||
raise ValueError("Triton tensors must have a power-of-two number of elements")
|
||||
self.numel = constexpr(self.numel)
|
||||
|
@@ -18,8 +18,8 @@ def num_warps(n):
|
||||
|
||||
@triton.jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
Out, A, LUT, R, stride_xz,
|
||||
extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
@@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
out, a, lut, rel_logits, a.stride(0),
|
||||
rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
|
@@ -26,9 +26,6 @@ def get_configs_io_bound():
|
||||
return configs
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
@@ -59,6 +56,9 @@ def get_configs_io_bound():
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
|
2
python/triton/runtime/__init__.py
Normal file
2
python/triton/runtime/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
204
python/triton/runtime/autotuner.py
Normal file
204
python/triton/runtime/autotuner.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from ..testing import do_bench
|
||||
from .jit import KernelInterface
|
||||
|
||||
|
||||
class Autotuner(KernelInterface):
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = dict()
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'early_config_prune' in prune_configs_by:
|
||||
early_config_prune = prune_configs_by['early_config_prune']
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return do_bench(kernel_call)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
`num_warps=8`, then each kernel instance will be automatically parallelized to
|
||||
cooperatively execute using `8 * 32 = 256` threads.
|
||||
:type num_warps: int
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
||||
function are args.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
return ', '.join(res)
|
||||
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
reset the value of the provided tensor to `zero` before running any configuration.
|
||||
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
|
||||
|
||||
def __init__(self, fn, arg_names, values) -> None:
|
||||
self.fn = fn
|
||||
self.values = values
|
||||
self.arg_names = arg_names
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
for v, heur in self.values.items():
|
||||
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
|
||||
return self.fn.run(*args, **kwargs)
|
||||
|
||||
|
||||
def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
.type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
def decorator(fn):
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
return decorator
|
415
python/triton/runtime/jit.py
Normal file
415
python/triton/runtime/jit.py
Normal file
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations, division
|
||||
|
||||
import ast
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
|
||||
"""
|
||||
This AST visitor is used to find dependencies of a JITFunction. This can
|
||||
be used to invalidate a JITFunction's hash when its source code -- or
|
||||
that of its dependencies -- changes.
|
||||
"""
|
||||
|
||||
def __init__(self, globals, src) -> None:
|
||||
super().__init__()
|
||||
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
|
||||
self.globals = globals
|
||||
|
||||
def visit_Name(self, node):
|
||||
return self.globals.get(node.id, None)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or lhs is triton:
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
def visit_Call(self, node):
|
||||
func = self.visit(node.func)
|
||||
if func is None:
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and func.__module__.startswith('triton.'):
|
||||
return
|
||||
assert isinstance(func, JITFunction)
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
self.ret = (self.ret + func.hash).encode("utf-8")
|
||||
self.ret = hashlib.md5(self.ret).hexdigest()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
with open(triton.compiler.__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# backend
|
||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# language
|
||||
language_path = os.path.join(*triton.__path__, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
try:
|
||||
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
|
||||
except Exception:
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface:
|
||||
|
||||
def __getitem__(self, grid):
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
def launcher(*args, **kwargs):
|
||||
return self.run(*args, grid=grid, **kwargs)
|
||||
return launcher
|
||||
|
||||
|
||||
class JITFunction(KernelInterface):
|
||||
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
|
||||
@staticmethod
|
||||
def _key_of(arg):
|
||||
if hasattr(arg, "dtype"):
|
||||
return arg.dtype
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**31 <= arg and arg <= 2**32 - 1:
|
||||
return "u32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
def _get_config(self, *args):
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(x, int):
|
||||
return x % JITFunction.divisibility == 0
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
if isinstance(key, (torch.dtype, triton.language.dtype)):
|
||||
ty = {
|
||||
torch.bool: 'i1',
|
||||
torch.float16: 'fp16',
|
||||
torch.bfloat16: 'bf16',
|
||||
torch.float32: 'fp32',
|
||||
torch.float64: 'fp64',
|
||||
torch.uint8: 'u8',
|
||||
torch.int8: 'i8',
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
|
||||
triton.language.uint8: 'u8',
|
||||
triton.language.uint16: 'u16',
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
triton.language.float8: 'fp8',
|
||||
}[key]
|
||||
return f'*{ty}'
|
||||
if key is None:
|
||||
return '*i8'
|
||||
assert isinstance(key, str)
|
||||
return key
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
return signature
|
||||
|
||||
def _make_constants(self, constexpr_key):
|
||||
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
|
||||
# cache key for constexpr argument values
|
||||
constexpr_keys = ', '.join(constexpr_args)
|
||||
# cache key for argument specialization
|
||||
specializations = []
|
||||
for i, arg in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
|
||||
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
|
||||
f'else (False,)']
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
|
||||
key = (version_key, sig_key, constexpr_key, spec_key)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
if callable(grid):
|
||||
grid = grid({{{grid_args}}})
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
if stream is None:
|
||||
stream = get_cuda_stream(device)
|
||||
try:
|
||||
bin = cache[key]
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
except KeyError:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
configs = self._get_config(*args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
|
||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
device = 0
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
|
||||
self.cache[key] = bin
|
||||
return bin
|
||||
return None
|
||||
"""
|
||||
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
|
||||
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
|
||||
"cache": self.cache, "triton": triton, "torch": torch}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
|
||||
# specialization hints
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
# cache of just-in-time compiled kernels
|
||||
self.cache = dict()
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
self.__globals__ = fn.__globals__
|
||||
self.__module__ = fn.__module__
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def cache_key(self):
|
||||
# TODO : hash should be attribute of `self`
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
return self.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
# the user might want to monkey-patch self.src dynamically.
|
||||
# Our unit tests do this, for example.
|
||||
def parse(self):
|
||||
tree = ast.parse(self.src)
|
||||
assert isinstance(tree, ast.Module)
|
||||
assert len(tree.body) == 1
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
self.hash = None
|
||||
JITFunction.cache_key.fget.cache_clear()
|
||||
|
||||
def __repr__(self):
|
||||
return f"JITFunction({self.module}:{self.fn.__name__})"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
if dtype == tensor.base.dtype:
|
||||
# Reinterpreting to the original interpretation; return the base.
|
||||
return tensor.base
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
@@ -7,7 +7,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .code_gen import OutOfResources
|
||||
from .compiler import OutOfResources
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
|
48
python/triton/utils.py
Normal file
48
python/triton/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
if dtype == tensor.base.dtype:
|
||||
# Reinterpreting to the original interpretation; return the base.
|
||||
return tensor.base
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
Reference in New Issue
Block a user