[FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the Triton runtime to make it leaner and to clearly separate the compilation step from the just-in-time caching logic. This should substantially reduce launch overhead.
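For illustration, a minimal sketch of what a launch looks like after the rewrite (the kernel below is a made-up example, not code from this PR): the launcher generated in triton/runtime/jit.py only computes a cache key and re-dispatches to a cached binary, and falls back to the standalone triton.compile(...) entry point on a miss.

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst_ptr, src_ptr, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(src_ptr + offs, mask=offs < N)
    tl.store(dst_ptr + offs, x, mask=offs < N)

src = torch.randn(1024, device='cuda')
dst = torch.empty_like(src)
# The generated launcher builds (version_key, sig_key, constexpr_key, spec_key),
# reuses a cached binary when possible, and otherwise calls triton.compile(...)
# followed by bin.c_wrapper(...) -- see runtime/jit.py further down in this diff.
_copy[(triton.cdiv(1024, 256),)](dst, src, 1024, BLOCK=256)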
@@ -3,6 +3,7 @@
 #include <memory>
 #include <string>
+#include <map>

 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -87,7 +87,6 @@ public:
   // Functions
   const functions_list_t &get_function_list() const { return functions_; }
-  functions_list_t &get_function_list() { return functions_; }
   function *get_function(const std::string& name) {
     if(symbols_.find(name) == symbols_.end())
       throw std::runtime_error("function " + name + " is not declared");
@@ -106,11 +106,11 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
   // run passes
   inliner.run(ir);
   dce.run(ir);
-  // ir.print(std::cout);
   peephole.run(ir);
   dce.run(ir);
   pipeline.run(ir);
   dce.run(ir);
+  // ir.print(std::cout);
   disassociate.run(ir);
   dce.run(ir);
   align.run(ir);
@@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) {
         assert(backend == ROCM);
         return hip_load_binary(name, asm_map, n_shared_bytes, dev);
       }, py::return_value_policy::take_ownership);

+  struct InstanceDescriptor
+  {
+    std::unordered_set<int> divisibleBy16;
+    std::unordered_set<int> equalTo1;
+  };
+
+  py::class_<InstanceDescriptor>(m, "instance_descriptor")
+      .def(py::init<>())
+      .def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
+      .def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
+      .def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
 }

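From Python, the new binding could be exercised roughly as follows (a sketch; the code_gen submodule path is an assumption based on the commented-out call in runtime/jit.py further down in this diff).

import triton._C.libtriton.triton as _triton

# Describe which argument indices are known to be divisible by 16 and which are
# statically equal to 1 (module path assumed, see note above).
desc = _triton.code_gen.instance_descriptor({0, 1, 2}, {3})
print(desc.divisible_by_16)  # {0, 1, 2}
print(desc.equal_to_1)       # {3}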
@@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) {
       .def("get", &ir::struct_type::get, ret::reference)
       .def_property_readonly("num_types", &ir::struct_type::get_num_types);

-  py::class_<ir::module>(m, "module")
+  py::class_<ir::module>(m, "module", py::dynamic_attr())
       .def(py::init<std::string, ir::builder &>())
       .def("has_function", &ir::module::has_function)
       .def("get_function", &ir::module::get_function, ret::reference)
+      .def("get_functions", &ir::module::get_function_list, ret::reference)
       .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
       .def("print", [](ir::module *self) {
         self->print(std::cout);
@@ -11,7 +11,7 @@ from numpy.random import RandomState
 import triton
 import triton._C.libtriton.triton as _triton
 import triton.language as tl
-from triton.code_gen import JITFunction, TensorWrapper, reinterpret
+from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret

 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     elif (op in ('%', '/') and
           ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
            (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
-        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+        with pytest.raises(triton.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
         assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
     else:
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     else:
         numpy_expr = None
     if 'float' in dtype_x + dtype_y:
-        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+        with pytest.raises(triton.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
         # The CompilationError must have been caused by a C++ exception with this text.
         assert re.match('invalid operands of type', str(exc_info.value.__cause__))
@@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
     def catch_compilation_error(kernel):
         try:
             kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-        except triton.code_gen.CompilationError as e:
+        except triton.CompilationError as e:
             np.testing.assert_(True)
         except BaseException:
             np.testing.assert_(False)
@@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache):
     assert 'ld.global.cg' not in ptx


-@pytest.mark.parametrize("N", [8, 10, 11, 1024])
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')
@@ -1221,10 +1221,8 @@ def test_vectorization(N):
     tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
     ptx = pgm.asm["ptx"]
-    if N % 4 == 0:
+    if N % 16 == 0:
         assert "ld.global.v4.b32" in ptx
-    elif N % 2 == 0:
-        assert "ld.global.v2.b32" in ptx
     else:
         assert "ld.global.b32" in ptx
     # triton.testing.assert_almost_equal(dst, src[:N])
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non

     def cache_hook(*args, **kwargs):
         nonlocal spec_type
-        spec_type = kwargs["compile"]["arg_types"][0][1]
+        spec_type = kwargs["compile"]["signature"][0]
     JITFunction.cache_hook = cache_hook

     @triton.jit
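The hook payload changed along with the runtime: based on _call_hook in the new runtime/jit.py (shown later in this diff), a sketch of what a cache hook now receives.

from triton.runtime.jit import JITFunction

def cache_hook(*args, **kwargs):
    # kwargs["key"] is the stringified cache key, kwargs["repr"] a human-readable
    # summary, and kwargs["compile"] holds signature/constants/configs/etc.
    print(kwargs["key"])
    print(kwargs["repr"])                  # e.g. "kernel[num_warps=4, num_stages=3](...)"
    print(kwargs["compile"]["signature"])  # {arg_index: type string}, replaces the old "arg_types"
    return False                           # a falsy return lets compilation proceed

JITFunction.cache_hook = cache_hook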
@@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
     x = torch.tensor([3.14159], device='cuda')

     if overflow:
-        with pytest.raises(RuntimeError, match='integer overflow'):
+        with pytest.raises(OverflowError):
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
@@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
     configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
     kernel = triton.ops._matmul.kernel
-    decorators = kernel.kernel_decorators
-    kernel.kernel_decorators = []
-    triton.autotune(configs, [])(kernel)
-    kernel.kernel_decorators += decorators[1:]
+    kernel.configs = configs
+    # kernel.run = kernel.run.run.run
     # get matrix shape
     M = BLOCK_M if M is None else M
     N = BLOCK_N if N is None else N
@@ -7,7 +7,7 @@ import torch

 import triton
 import triton.language as tl
-from triton.code_gen import JITFunction
+from triton.runtime.jit import JITFunction

 tmpdir = ".tmp"

@@ -99,16 +99,16 @@ def test_specialize(mode):
     reset_tmp_dir()
     x = torch.empty(1, dtype=torch.int32, device='cuda')
     function = {'enable': kernel, 'disable': kernel_nospec}[mode]
-    target = {'enable': 5, 'disable': 1}[mode]
+    target = {'enable': 3, 'disable': 1}[mode]
     for i in [1, 2, 4, 8, 16, 32]:
         function[(1,)](x, i, BLOCK=512)
     assert counter == target


 @pytest.mark.parametrize("value, value_type", [
-    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
-    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
-    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
+    (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
 ])
 def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

@@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non

     def get_cache_str(*args, **kwargs):
         nonlocal cache_str
-        cache_str = kwargs['key'].split('-')
-    triton.code_gen.JITFunction.cache_hook = get_cache_str
+        cache_str = kwargs["repr"]
+    triton.JITFunction.cache_hook = get_cache_str
     reset_tmp_dir()
     x = torch.tensor([3.14159], device='cuda')
     kernel[(1, )](value, x)
-    triton.code_gen.JITFunction.cache_hook = None
+    triton.JITFunction.cache_hook = None

-    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
+    cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
     spec_type = None if cache_str_match is None else cache_str_match.group(1)
     assert spec_type == value_type

@@ -6,9 +6,10 @@ __version__ = '2.0.0'
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
-    JITFunction, Config, Autotuner, reinterpret
+from .utils import *
+from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
+from .runtime.jit import jit
+from .compiler import compile, CompilationError
 from . import language
-from . import code_gen
 from . import testing
 from . import ops
File diff suppressed because it is too large.
@@ -420,7 +420,7 @@ class tensor:
         self.numel = 1
         for s in self.shape:
             self.numel *= s
-        is_pow2 = (self.numel and (not(self.numel & (self.numel - 1))))
+        is_pow2 = (self.numel and (not (self.numel & (self.numel - 1))))
         if not is_pow2:
             raise ValueError("Triton tensors must have a power-of-two number of elements")
         self.numel = constexpr(self.numel)
@@ -18,8 +18,8 @@ def num_warps(n):

 @triton.jit
 def _blocksparse_softmax_fwd(
-    Out, A, stride_xz, LUT,
-    R, extent, stride_zr, stride_hr,  # relative attention
+    Out, A, LUT, R, stride_xz,
+    extent, stride_zr, stride_hr,  # relative attention
     scale, is_causal,
     ROW_SIZE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
@@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function):
         # enqueue kernel
         out = torch.empty_like(a)
         _blocksparse_softmax_fwd[grid](
-            out, a, a.stride(0), lut,
-            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
+            out, a, lut, rel_logits, a.stride(0),
+            rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
             scale,
             is_causal,
             BLOCK_SIZE=block,
@@ -26,9 +26,6 @@ def get_configs_io_bound():
     return configs


-@triton.heuristics({
-    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
         'top_k': 10
     },
 )
+@triton.heuristics({
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
 @triton.jit
 def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
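The decorator order changes because Autotuner, Heuristics, and JITFunction are now all KernelInterface wrappers that chain through .run(): @triton.jit is applied first, @triton.heuristics wraps the resulting JITFunction, and @triton.autotune wraps the Heuristics object, so a heuristic such as EVEN_K can see the meta-parameters chosen by the autotuner. A stripped-down sketch of the new stacking (kernel body elided):

import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({'BLOCK_K': 32, 'SPLIT_K': 1}, num_warps=4)],
    key=['K'],
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
            BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
    pass  # body elided; only the decorator order matters here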
python/triton/runtime/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .autotuner import Config, Heuristics, autotune, heuristics  # noqa: F401
from .jit import JITFunction, KernelInterface, version_key  # noqa: F401
python/triton/runtime/autotuner.py (new file, 204 lines)
@@ -0,0 +1,204 @@
from __future__ import annotations

import builtins
import time
from typing import Dict

from ..testing import do_bench
from .jit import KernelInterface


class Autotuner(KernelInterface):
    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
        '''
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predict running time with different configs, returns running time
            'top_k': number of configs to bench
            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
        '''
        if not configs:
            self.configs = [Config(dict(), num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.cache = dict()
        # hook to reset all required tensors to zero before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()
            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            if 'early_config_prune' in prune_configs_by:
                early_config_prune = prune_configs_by['early_config_prune']
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
        return do_bench(kernel_call)

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple([args[i] for i in self.key_idx])
            if key not in self.cache:
                # prune configs
                pruned_configs = self.configs
                if self.early_config_prune:
                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
                if self.perf_model:
                    top_k = self.configs_top_k
                    if isinstance(top_k, float) and top_k <= 1.0:
                        top_k = int(len(self.configs) * top_k)
                    if len(pruned_configs) > top_k:
                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs)
                           for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)


class Config:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type meta: dict[Str, Any]
    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
                     `num_warps=8`, then each kernel instance will be automatically parallelized to
                     cooperatively execute using `8 * 32 = 256` threads.
    :type num_warps: int
    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                      Mostly useful for matrix multiplication workloads on SM80+ GPUs.
    :type num_stages: int
    :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                    function are args.
    """

    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
        self.kwargs = kwargs
        self.num_warps = num_warps
        self.num_stages = num_stages
        self.pre_hook = pre_hook

    def __str__(self):
        res = []
        for k, v in self.kwargs.items():
            res.append(f'{k}: {v}')
        res.append(f'num_warps: {self.num_warps}')
        res.append(f'num_stages: {self.num_stages}')
        return ', '.join(res)


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.

    .. highlight:: python
    .. code-block:: python

        @triton.autotune(configs=[
            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size']  # the two above configs will be evaluated anytime
                          # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']

    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.

    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predict running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """
    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)

    return decorator


class Heuristics(KernelInterface):

    def __init__(self, fn, arg_names, values) -> None:
        self.fn = fn
        self.values = values
        self.arg_names = arg_names

    def run(self, *args, **kwargs):
        for v, heur in self.values.items():
            kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
        return self.fn.run(*args, **kwargs)


def heuristics(values):
    """
    Decorator for specifying how the values of certain meta-parameters may be computed.
    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

    .. highlight:: python
    .. code-block:: python

        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']  # smallest power-of-two >= x_size

    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                   each such function takes a list of positional arguments as input.
    :type values: dict[str, Callable[[list[Any]], Any]]
    """
    def decorator(fn):
        return Heuristics(fn, fn.arg_names, values)

    return decorator
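A hypothetical prune_configs_by setup for the Autotuner above: the perf_model is called with the kernel arguments plus each config's kwargs, num_warps, and num_stages, and only the top_k configurations with the best predicted time are actually benchmarked. All names below are illustrative.

import triton
import triton.language as tl

def perf_model(M, N, K, BLOCK_M, BLOCK_N, num_warps, num_stages, **kwargs):
    # toy running-time estimate; lower is better
    return (M * N * K) / (BLOCK_M * BLOCK_N * num_warps)

@triton.autotune(
    configs=[triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=4),
             triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=8)],
    key=['M', 'N', 'K'],
    prune_configs_by={'perf_model': perf_model, 'top_k': 1},
)
@triton.jit
def _matmul(A, B, C, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pass  # body elided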
python/triton/runtime/jit.py (new file, 415 lines)
@@ -0,0 +1,415 @@
from __future__ import annotations, division

import ast
import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import namedtuple

import torch

import triton

try:
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream

# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------


class DependenciesFinder(ast.NodeVisitor):
    """
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.
    """

    def __init__(self, globals, src) -> None:
        super().__init__()
        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
        self.globals = globals

    def visit_Name(self, node):
        return self.globals.get(node.id, None)

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        if lhs is None or lhs is triton:
            return None
        return getattr(lhs, node.attr)

    def visit_Call(self, node):
        func = self.visit(node.func)
        if func is None:
            return
        if inspect.isbuiltin(func):
            return
        if func.__module__ and func.__module__.startswith('triton.'):
            return
        assert isinstance(func, JITFunction)
        if func.hash is None:
            tree = ast.parse(func.src)
            finder = DependenciesFinder(func.__globals__, func.src)
            finder.visit(tree)
            func.hash = finder.ret
        self.ret = (self.ret + func.hash).encode("utf-8")
        self.ret = hashlib.md5(self.ret).hexdigest()

# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------


@functools.lru_cache()
def version_key():
    import pkgutil
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    with open(triton.compiler.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # backend
    with open(triton._C.libtriton.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # language
    language_path = os.path.join(*triton.__path__, 'language')
    for lib in pkgutil.iter_modules([language_path]):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = ''
    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)


class KernelInterface:

    def __getitem__(self, grid):
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        def launcher(*args, **kwargs):
            return self.run(*args, grid=grid, **kwargs)
        return launcher


class JITFunction(KernelInterface):

    cache_hook = None
    divisibility = 16

    @staticmethod
    def _key_of(arg):
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            return "i1"
        elif isinstance(arg, int):
            if -2**31 <= arg and arg <= 2**31 - 1:
                return "i32"
            elif 2**31 <= arg and arg <= 2**32 - 1:
                return "u32"
            elif 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return 'fp32'
        elif arg is None:
            return None
        else:
            raise TypeError(f'Unsupported type {type(arg)} for {arg}')

    @staticmethod
    def _spec_of(arg):
        if hasattr(arg, "data_ptr"):
            return (arg.data_ptr() % JITFunction.divisibility == 0)
        elif isinstance(arg, int):
            return (arg % 16 == 0, arg == 1)
        return (arg is None, )

    def _get_config(self, *args):
        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                return True
            return False
        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
        # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)

    @staticmethod
    def _type_of(key):
        if isinstance(key, (torch.dtype, triton.language.dtype)):
            ty = {
                torch.bool: 'i1',
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16',
                torch.float32: 'fp32',
                torch.float64: 'fp64',
                torch.uint8: 'u8',
                torch.int8: 'i8',
                torch.int16: 'i16',
                torch.int32: 'i32',
                torch.int64: 'i64',

                triton.language.uint8: 'u8',
                triton.language.uint16: 'u16',
                triton.language.uint32: 'u32',
                triton.language.uint64: 'u64',
                triton.language.float8: 'fp8',
            }[key]
            return f'*{ty}'
        if key is None:
            return '*i8'
        assert isinstance(key, str)
        return key

    def _make_signature(self, sig_key):
        signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
        return signature

    def _make_constants(self, constexpr_key):
        constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
        return constants

    def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        if JITFunction.cache_hook is None:
            return False
        name = self.fn.__name__
        module = self.fn.__module__
        arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
        repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
        key = str(key)

        class LegacyCompiler:
            def __init__(self, module, name):
                self.module = module
                self.name = name
                pass

        kwargs = dict(signature=signature, device=device, constants=constants,
                      num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
                      configs=configs)

        return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)

    def _make_launcher(self):
        regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
        constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
        args = ', '.join(regular_args)
        # cache key for regular argument type
        sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
        # cache key for constexpr argument values
        constexpr_keys = ', '.join(constexpr_args)
        # cache key for argument specialization
        specializations = []
        for i, arg in enumerate(regular_args):
            if i in self.do_not_specialize:
                continue
            specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
                                f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
                                f'else (False,)']
        spec_keys = ', '.join(specializations)
        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])

        src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
    sig_key = {sig_keys},
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
    key = (version_key, sig_key, constexpr_key, spec_key)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    if stream is None:
        stream = get_cuda_stream(device)
    try:
        bin = cache[key]
        bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
        return bin
    # kernel not cached -- compile
    except KeyError:
        # build dict of constant values
        args = [{args}]
        configs = self._get_config(*args),
        constants = self._make_constants(constexpr_key)
        constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
        constants.update({{i: 1 for i in configs[0].equal_to_1}})
        # build kernel signature -- doesn't include specialized arguments
        all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
        # build stub signature -- includes arguments that are specialized
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {i} is not supported")
        device = 0
        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
            bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
            bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
            self.cache[key] = bin
            return bin
        return None
"""
        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
                 "cache": self.cache, "triton": triton, "torch": torch}
        exec(src, scope)
        return scope[self.fn.__name__]

    def __init__(self, fn, version=None, do_not_specialize=None):
        self.fn = fn
        self.module = fn.__module__
        self.version = version
        # function signature information
        signature = inspect.signature(fn)
        self.arg_names = [v.name for v in signature.parameters.values()]
        self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
        # specialization hints
        self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
        self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
        # function source code (without decorators)
        self.src = textwrap.dedent(inspect.getsource(fn))
        self.src = self.src[self.src.find("def"):]
        # cache of just-in-time compiled kernels
        self.cache = dict()
        self.hash = None
        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel_decorators = []
        self.kernel = None
        # annotations
        self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
        self.__annotations__ = fn.__annotations__
        # index of constexprs
        self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
        # launcher
        self.run = self._make_launcher()
        # re-use docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    @property
    @functools.lru_cache()
    def cache_key(self):
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + version_key()
        return self.hash

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        # - when kernel decorators change, cached kernel
        #   needs to be cleared
        if name == 'kernel_decorators':
            self.kernel = None
        super(JITFunction, self).__setattr__(name, value)
        # - when `.src` attribute is set, cache path needs
        #   to be reinitialized
        if name == 'src':
            self.hash = None
            JITFunction.cache_key.fget.cache_clear()

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"


# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------


def jit(*args, **kwargs):
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * objects within the triton.language package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """
    if args:
        assert len(args) == 1
        assert callable(args[0])
        return JITFunction(args[0], **kwargs)
    else:
        def decorator(fn):
            return JITFunction(fn, **kwargs)
        return decorator


class TensorWrapper:
    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        return self.base.data_ptr()

    def __str__(self) -> str:
        return f'TensorWrapper[{self.dtype}]({self.base})'


def reinterpret(tensor, dtype):
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Reinterpreting to the original interpretation; return the base.
            return tensor.base
        else:
            # Reinterpreting a wrapped tensor to a different type.
            return TensorWrapper(tensor.base, dtype)
    elif isinstance(tensor, torch.Tensor):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    else:
        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
@@ -7,7 +7,7 @@ from contextlib import contextmanager
 import torch

 import triton._C.libtriton.triton as _triton
-from .code_gen import OutOfResources
+from .compiler import OutOfResources

 try:
     import triton._C.libtriton.cutlass as _cutlass
python/triton/utils.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from __future__ import annotations

import torch


def cdiv(x, y):
    return (x + y - 1) // y


def next_power_of_2(n):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


class TensorWrapper:
    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        return self.base.data_ptr()

    def __str__(self) -> str:
        return f'TensorWrapper[{self.dtype}]({self.base})'


def reinterpret(tensor, dtype):
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Reinterpreting to the original interpretation; return the base.
            return tensor.base
        else:
            # Reinterpreting a wrapped tensor to a different type.
            return TensorWrapper(tensor.base, dtype)
    elif isinstance(tensor, torch.Tensor):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    else:
        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
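These helpers are re-exported at the top level via `from .utils import *` (see the __init__.py hunk above); a typical use, sketched here, is computing launch grids and power-of-two block extents.

import triton

assert triton.cdiv(4096, 256) == 16          # number of blocks needed to cover 4096 elements
assert triton.next_power_of_2(1000) == 1024  # round a size up to a power of two

# e.g. a grid callable for a 1D kernel launch
grid = lambda meta: (triton.cdiv(4096, meta['BLOCK_SIZE']),)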