diff --git a/include/triton/codegen/extern_lib.h b/include/triton/codegen/extern_lib.h index c161ff142..02e991407 100644 --- a/include/triton/codegen/extern_lib.h +++ b/include/triton/codegen/extern_lib.h @@ -3,6 +3,7 @@ #include #include +#include #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 1ed0b6646..d09a51a22 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -87,7 +87,6 @@ public: // Functions const functions_list_t &get_function_list() const { return functions_; } - functions_list_t &get_function_list() { return functions_; } function *get_function(const std::string& name) { if(symbols_.find(name) == symbols_.end()) throw std::runtime_error("function " + name + " is not declared"); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 645f10978..024a838d9 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -106,11 +106,11 @@ std::unique_ptr add_passes_to_emit_bin( // run passes inliner.run(ir); dce.run(ir); - // ir.print(std::cout); peephole.run(ir); dce.run(ir); pipeline.run(ir); dce.run(ir); + // ir.print(std::cout); disassociate.run(ir); dce.run(ir); align.run(ir); diff --git a/python/src/triton.cc b/python/src/triton.cc index d56ff8430..8bfb076c3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) { assert(backend == ROCM); return hip_load_binary(name, asm_map, n_shared_bytes, dev); }, py::return_value_policy::take_ownership); + + + struct InstanceDescriptor + { + std::unordered_set divisibleBy16; + std::unordered_set equalTo1; + }; + + py::class_(m, "instance_descriptor") + .def(py::init<>()) + .def(py::init, std::unordered_set>()) + .def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16) + .def_readonly("equal_to_1", &InstanceDescriptor::equalTo1); } @@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) { .def("get", &ir::struct_type::get, ret::reference) .def_property_readonly("num_types", &ir::struct_type::get_num_types); - py::class_(m, "module") + py::class_(m, "module", py::dynamic_attr()) .def(py::init()) .def("has_function", &ir::module::has_function) .def("get_function", &ir::module::get_function, ret::reference) + .def("get_functions", &ir::module::get_function_list, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) .def("print", [](ir::module *self) { self->print(std::cout); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 46ddfe760..d00de5be5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -11,7 +11,7 @@ from numpy.random import RandomState import triton import triton._C.libtriton.triton as _triton import triton.language as tl -from triton.code_gen import JITFunction, TensorWrapper, reinterpret +from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or (dtype_x in uint_dtypes and dtype_y in int_dtypes))): - with pytest.raises(triton.code_gen.CompilationError) as exc_info: + with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) assert re.match('Cannot use .* because 
they have different signedness', str(exc_info.value.__cause__)) else: @@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): else: numpy_expr = None if 'float' in dtype_x + dtype_y: - with pytest.raises(triton.code_gen.CompilationError) as exc_info: + with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) # The CompilationError must have been caused by a C++ exception with this text. assert re.match('invalid operands of type', str(exc_info.value.__cause__)) @@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'): def catch_compilation_error(kernel): try: kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) - except triton.code_gen.CompilationError as e: + except triton.CompilationError as e: np.testing.assert_(True) except BaseException: np.testing.assert_(False) @@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache): assert 'ld.global.cg' not in ptx -@pytest.mark.parametrize("N", [8, 10, 11, 1024]) +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) def test_vectorization(N): src = torch.empty(1024, device='cuda') dst = torch.empty(1024, device='cuda') @@ -1221,10 +1221,8 @@ def test_vectorization(N): tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) ptx = pgm.asm["ptx"] - if N % 4 == 0: + if N % 16 == 0: assert "ld.global.v4.b32" in ptx - elif N % 2 == 0: - assert "ld.global.v2.b32" in ptx else: assert "ld.global.b32" in ptx # triton.testing.assert_almost_equal(dst, src[:N]) @@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non def cache_hook(*args, **kwargs): nonlocal spec_type - spec_type = kwargs["compile"]["arg_types"][0][1] + spec_type = kwargs["compile"]["signature"][0] JITFunction.cache_hook = cache_hook @triton.jit @@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda' x = torch.tensor([3.14159], device='cuda') if overflow: - with pytest.raises(RuntimeError, match='integer overflow'): + with pytest.raises(OverflowError): kernel[(1, )](value, x) else: kernel[(1, )](value, x) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 514fbab7b..e14ea6ae7 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] kernel = triton.ops._matmul.kernel - decorators = kernel.kernel_decorators - kernel.kernel_decorators = [] - triton.autotune(configs, [])(kernel) - kernel.kernel_decorators += decorators[1:] + kernel.configs = configs + # kernel.run = kernel.run.run.run + # get matrix shape M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index fd95dbd38..6fad3af3d 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -7,7 +7,7 @@ import torch import triton import triton.language as tl -from triton.code_gen import JITFunction +from triton.runtime.jit import JITFunction tmpdir = ".tmp" @@ -99,16 +99,16 @@ def test_specialize(mode): reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') function = {'enable': 
kernel, 'disable': kernel_nospec}[mode] - target = {'enable': 5, 'disable': 1}[mode] + target = {'enable': 3, 'disable': 1}[mode] for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target @pytest.mark.parametrize("value, value_type", [ - (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), - (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), - (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') + (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), + (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64') ]) def test_value_specialization(value: int, value_type: str, device='cuda') -> None: @@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non def get_cache_str(*args, **kwargs): nonlocal cache_str - cache_str = kwargs['key'].split('-') - triton.code_gen.JITFunction.cache_hook = get_cache_str + cache_str = kwargs["repr"] + triton.JITFunction.cache_hook = get_cache_str reset_tmp_dir() x = torch.tensor([3.14159], device='cuda') kernel[(1, )](value, x) - triton.code_gen.JITFunction.cache_hook = None + triton.JITFunction.cache_hook = None - cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) + cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str) spec_type = None if cache_str_match is None else cache_str_match.group(1) assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 37ba46efc..c620543ee 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,9 +6,10 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ - JITFunction, Config, Autotuner, reinterpret +from .utils import * +from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface +from .runtime.jit import jit +from .compiler import compile, CompilationError from . import language -from . import code_gen from . import testing from . import ops diff --git a/python/triton/code_gen.py b/python/triton/compiler.py similarity index 50% rename from python/triton/code_gen.py rename to python/triton/compiler.py index e2956aea9..98ccc6b1a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/compiler.py @@ -1,39 +1,48 @@ from __future__ import annotations import ast -import builtins +import contextlib import functools import hashlib -import inspect +import io import os -import pickle import subprocess import sys +import sysconfig import tempfile -import textwrap -import threading -import time import warnings -from typing import Dict, Set, Tuple, Union +from typing import Any, Dict, Set, Tuple, Union +import setuptools import torch from filelock import FileLock import triton import triton._C.libtriton.triton as _triton -from .tools.disasm import extract - -try: - from torch._C import _cuda_getCurrentRawStream as get_cuda_stream -except ImportError: - get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream -def current_cuda_stream(device_idx=0): - # Torch's torch.cuda.current_stream() is slow. We provide this - # function to give the user an opportunity to monkey-patch their - # own faster current stream lookup. 
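The test changes above (test_core.py reading `kwargs["compile"]["signature"][0]`, test_cache.py reading `kwargs["repr"]`) all track a reworked `JITFunction.cache_hook` contract. The sketch below shows how a hook might be registered against the new interface; only the `repr` key and the `compile["signature"]` field are evidenced by these tests, so the remaining details (full keyword set, return-value semantics) are assumptions rather than documented behavior.

```python
import triton


def my_cache_hook(*args, **kwargs):
    # "repr" is a human-readable description of the specialized kernel;
    # test_cache.py above matches it against r".*VALUE: (\w+).*".
    print("compiling:", kwargs.get("repr"))
    # "compile" carries the compilation options; test_core.py reads the
    # specialized type of the first argument from its "signature" field.
    compile_info = kwargs.get("compile", {})
    print("signature:", compile_info.get("signature"))
    # In the pre-refactor hook a truthy return value skipped compilation;
    # that is assumed (not shown in this hunk) to still be the convention.
    return False


# The hook fires the first time a @triton.jit kernel is compiled for a
# previously unseen specialization.
triton.JITFunction.cache_hook = my_cache_hook
```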
- return get_cuda_stream(device_idx) +def str_to_ty(name): + if name[0] == "*": + ty = str_to_ty(name[1:]) + return triton.language.pointer_type(ty) + tys = { + "i1": triton.language.int1, + "fp8": triton.language.float8, + "fp16": triton.language.float16, + "bf16": triton.language.bfloat16, + "fp32": triton.language.float32, + "fp64": triton.language.float64, + "i8": triton.language.int8, + "i16": triton.language.int16, + "i32": triton.language.int32, + "i64": triton.language.int64, + "u8": triton.language.uint8, + "u16": triton.language.uint16, + "u32": triton.language.uint32, + "u64": triton.language.uint64, + "B": triton.language.int1, + } + return tys[name] def mangle_ty(ty): @@ -63,7 +72,7 @@ def mangle_ty(ty): def mangle_fn(name, arg_tys, constants): # doesn't mangle ret type, which must be a function of arg tys mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + key = lambda x: x.__name__ if isinstance(x, triton.runtime.JITFunction) else repr(x) mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') @@ -218,7 +227,8 @@ class ValueConstructor: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, prototypes=None, module=None, is_kernel=False): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, spec_to_1=None, prototypes=None, module=None, is_kernel=False): + self.spec_to_1 = set() if spec_to_1 is None else spec_to_1 self.prototypes = dict() if prototypes is None else prototypes self.builder = _triton.ir.builder(context) self.module = _triton.ir.module('', self.builder) if module is None else module @@ -226,6 +236,7 @@ class CodeGenerator(ast.NodeVisitor): self.attributes = attributes self.constants = constants self.last_node = None + self.function_name = function_name self.is_kernel = is_kernel self.value_constructor = ValueConstructor(self.module, self.builder, gscope) @@ -260,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor): return ret def visit_FunctionDef(self, node): - arg_names, kwarg_names = self.visit(node.args) + arg_names, arg_annotations, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): arg_node = node.args.args[-i - 1] @@ -273,28 +284,27 @@ class CodeGenerator(ast.NodeVisitor): init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) # initialize function - fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) - self.prototypes[fn_name] = self.prototype - fn = self.module.get_or_insert_function(fn_name, self.prototype.to_ir(self.builder)) + self.prototypes[self.function_name] = self.prototype + fn = self.module.get_or_insert_function(self.function_name, self.prototype.to_ir(self.builder)) fn.set_is_kernel(self.is_kernel) arg_values = [] idx = 0 - for i, arg_name in enumerate(arg_names): + for i, (arg_name, annotation) in enumerate(zip(arg_names, arg_annotations)): if i in self.constants: cst = self.constants[i] if not isinstance(cst, triton.language.constexpr): cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) - else: - if i in self.attributes: - is_ptr = fn.args[idx].type.is_ptr() - attr = 'aligned' if is_ptr else 'multiple_of' - attr = getattr(_triton.ir.attribute_kind, attr) - attr = 
_triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(idx + 1, attr) - fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) - idx += 1 + continue + if i in self.attributes: + is_ptr = fn.args[idx].type.is_ptr() + attr = 'aligned' if is_ptr else 'multiple_of' + attr = getattr(_triton.ir.attribute_kind, attr) + attr = _triton.ir.attribute(attr, self.attributes[i][1]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) + idx += 1 insert_pt = self.builder.get_insert_block() entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) @@ -309,20 +319,23 @@ class CodeGenerator(ast.NodeVisitor): self.builder.ret_void() else: # a bit hacky: we only know the return type at the last moment so we update type info here - self.module.reset_ret_ty(fn_name, self.last_ret.type.to_ir(self.builder)) + self.module.reset_ret_ty(self.function_name, self.last_ret.type.to_ir(self.builder)) self.prototype.ret_type = self.last_ret.type self.builder.set_insert_block(insert_pt) def visit_arguments(self, node): arg_names = [] + arg_annotations = [] for arg in node.args: - arg_names += [self.visit(arg)] + curr = self.visit(arg) + arg_names += [curr[0]] + arg_annotations += [curr[1]] kwarg_names = self.visit(node.kwarg) - return arg_names, kwarg_names + return arg_names, arg_annotations, kwarg_names def visit_arg(self, node): ast.NodeVisitor.generic_visit(self, node) - return node.arg + return node.arg, node.annotation def visit_AnnAssign(self, node): # extract attributes @@ -661,7 +674,7 @@ class CodeGenerator(ast.NodeVisitor): kws.update(self.visit(keyword)) args = [self.visit(arg) for arg in node.args] - if isinstance(fn, JITFunction): + if isinstance(fn, triton.runtime.JITFunction): from inspect import getcallargs args = getcallargs(fn.fn, *args, **kws) args = [args[name] for name in fn.arg_names] @@ -681,7 +694,7 @@ class CodeGenerator(ast.NodeVisitor): ret_type = triton.language.void prototype = triton.language.function_type(ret_type, arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, prototypes=self.prototypes, module=self.module) + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, function_name=fn_name, prototypes=self.prototypes, module=self.module) generator.visit(fn.parse()) symbol = self.module.get_function(fn_name) ret = self.builder.call(symbol, arg_vals) @@ -758,52 +771,6 @@ class CodeGenerator(ast.NodeVisitor): raise NotImplementedError("Unsupported node: {}".format(typename)) -class Binary: - def __init__(self, backend, name, asm, shared_mem, num_warps): - self.backend = backend - self.name = name - self.asm = asm - self.shared_mem = shared_mem - self.num_warps = num_warps - - -class LoadedBinary: - def __init__(self, device: int, bin: Binary): - module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, - device) - self.bin = bin - self.asm = bin.asm - self.sass = '' - self.module = module - self.kernel = kernel - self.n_regs = n_regs - self.n_spills = n_spills - self.device = device - self.shared_mem = bin.shared_mem - - def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): - _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, - args, 
self.bin.shared_mem) - - def get_sass(self, fun=None): - if self.sass: - return self.sass - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass - - class CompilationError(Exception): def __init__(self, src, node): self.message = f'at {node.lineno}:{node.col_offset}:\n' @@ -833,694 +800,464 @@ class OutOfResources(Exception): return (type(self), (self.required, self.limit, self.name)) -class Kernel: - - @staticmethod - def _type_name(obj): - type_names = { - triton.language.float8: 'f8', - torch.bfloat16: 'bf16', - torch.float16: 'f16', - torch.float32: 'f32', - torch.float64: 'f64', - torch.bool: 'i1', - torch.uint8: 'u8', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - triton.language.uint8: 'u8', - triton.language.uint16: 'u16', - triton.language.uint32: 'u32', - triton.language.uint64: 'u64', - } - if hasattr(obj, 'data_ptr'): - return type_names[obj.dtype] - if isinstance(obj, triton.language.constexpr): - obj = obj.value - if isinstance(obj, int): - if -2**31 <= obj < 2**31: - return 'i32' - elif 2**31 <= obj < 2**32: - return 'u32' - elif -2**63 <= obj < 2**63: - return 'i64' - elif 2**63 <= obj < 2**64: - return 'u64' - else: - raise ValueError(f'integer overflow representing {obj}') - if isinstance(obj, float): - return 'f' - if isinstance(obj, bool): - return 'B' - if isinstance(obj, str): - return 'str' - raise NotImplementedError(f'could not compute type name for {obj}') - - @staticmethod - def _to_python_ir(obj): - # convert torch.Tensor to Triton IR pointers - if hasattr(obj, 'data_ptr'): - name = Kernel._type_name(obj) - return 'ptr', name - # default path returns triton.ir.type directly - name = Kernel._type_name(obj) - return 'scalar', name - - @staticmethod - def _to_triton_ir(obj): - which, name = obj - type_map = { - 'I': triton.language.int32, - 'L': triton.language.int64, - 'f': triton.language.float32, - 'B': triton.language.int1, - 'f8': triton.language.float8, - 'f16': triton.language.float16, - 'bf16': triton.language.bfloat16, - 'f32': triton.language.float32, - 'f64': triton.language.float64, - 'i1': triton.language.int1, - 'i8': triton.language.int8, - 'i16': triton.language.int16, - 'i32': triton.language.int32, - 'i64': triton.language.int64, - 'u8': triton.language.uint8, - 'u16': triton.language.uint16, - 'u32': triton.language.uint32, - 'u64': triton.language.uint64, - } - # convert torch.Tensor to Triton IR pointers - if which == 'ptr': - elt_ty = type_map[name] - return triton.language.pointer_type(elt_ty, 1) - # default path returns triton.ir.type directly - return type_map[name] - - @staticmethod - def pow2_divisor(N): - if N % 16 == 0: - return 16 - if N % 8 == 0: - return 8 - if N % 4 == 0: - return 4 - if N % 2 == 0: - return 2 - return 1 - - def __init__(self, fn): - self.fn = fn - self.cache_key = {} - - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs): - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - - # attributes - attributes = dict() - for i, arg in enumerate(wargs): - if i in self.fn.do_not_specialize: - continue - if isinstance(arg, int): - attributes[i] = Kernel.pow2_divisor(arg) - elif i in tensor_idxs: - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - attributes[i] = min(Kernel.pow2_divisor(addr), - Kernel.pow2_divisor(range_size)) - # 
transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, - extern_libs=extern_libs, is_manual_warmup=False) - - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs): - assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." - # handle arguments passed by name - kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} - wargs = list(wargs) - for i, pos in enumerate(sorted(kwargs)): - wargs.insert(pos + i, kwargs[pos]) - if len(wargs) != len(self.fn.arg_names): - raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") - # handle annotations - for pos, _type in self.fn.annotations.items(): - assert _type == triton.language.constexpr, "only constexpr annotations are supported for now" - wargs[pos] = _type(wargs[pos]) - # check that tensors are on GPU. - # for arg in wargs: - # if hasattr(arg, 'data_ptr'): - # assert arg.is_cuda, "All tensors must be on GPU!" - # set device (i.e., make sure torch has the context initialized) - device = torch.cuda.current_device() - # torch creates new thread for backward pass that may have uninitlialized context - # no way to know if this function should or shouldn't initialize the cuda context - # so we're being conservative here - torch.cuda.set_device(device) - if device not in self.cache_key: - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - self.cache_key[device] = self.fn.cache_key + cc - cache_key = self.cache_key[device] - stream = current_cuda_stream(device) - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, - device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache, - grid) +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix -class Launcher: - def __init__(self, kernel, grid): - self.kernel = kernel - self.grid = grid +def make_triton_ir(fn, signature, specialization, constants): + context = _triton.ir.context() + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) + tys = list(signature.values()) + new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1} + new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in 
signature.items() if k not in constants] - def __call__(self, *wargs, **kwargs): - return self.kernel(*wargs, **kwargs, grid=self.grid) + prototype = triton.language.function_type(triton.language.void, arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True) + try: + generator.visit(fn.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(fn.src, node) from e + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret, generator -class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): - ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - ''' - if not configs: - self.configs = [Config(dict(), num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.cache = dict() - self.kernel = kernel - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." 
- ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - return triton.testing.do_bench(kernel_call) - - def __call__(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple([args[i] for i in self.key_idx]) - if key not in self.cache: - # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) +def make_ptx(mod: Any, device: int) -> Tuple[str, int]: + ''' + Translate TritonGPU module to PTX code. + :param mod: a TritonGPU dialect module + :return: + - PTX code + - shared memory alloaction size + ''' + return _triton.translate_triton_gpu_to_ptx(mod, device) -_version_key_lock = threading.Lock() -_version_key = None +def make_cubin(ptx, device): + ''' + Compile TritonGPU module to cubin. + :param ptx: ptx code + :param device: CUDA device + :return: str + ''' + return _triton.compile_ptx_to_cubin(ptx, device) -def version_key(): - global _version_key - - if _version_key is not None: - return _version_key - - with _version_key_lock: - if _version_key is not None: - return _version_key - - import pkgutil - contents = [] - # frontend - with open(triton.code_gen.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # backend - with open(triton._C.libtriton.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # language - language_path = os.path.join(*triton.__path__, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # ptxas version - try: - ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() - except Exception: - ptxas_version = '' - _version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) - return _version_key +def ptx_get_kernel_name(ptx: str) -> str: + ''' + Get kernel name from PTX code. + This Kernel name is required when launching the kernel. + ''' + # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. 
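The naming scheme used by the new compilation path is worth spelling out: `make_triton_ir` above appends `kernel_suffix(signature, specialization)` to the Python function name, so the specialization captured by the pybind `instance_descriptor` (`divisible_by_16`, `equal_to_1`) becomes visible in the generated symbol. The self-contained sketch below restates that helper with a plain namedtuple standing in for the compiled-extension descriptor; the kernel name `add_kernel` is made up for illustration.

```python
from collections import namedtuple

# Stand-in for the pybind11 `instance_descriptor` bound in triton.cc above,
# used here only so the example runs without the compiled extension.
InstanceDescriptor = namedtuple("InstanceDescriptor",
                                ["divisible_by_16", "equal_to_1"])


def kernel_suffix(signature, specialization):
    # Mirrors the helper added in compiler.py: one entry per argument index,
    # 'c' if the argument is specialized to the constant 1, 'd' if it is
    # known to be divisible by 16.
    suffix = ''
    for i, _ in enumerate(signature):
        suffix += str(i)
        if i in specialization.equal_to_1:
            suffix += 'c'
        if i in specialization.divisible_by_16:
            suffix += 'd'
    return suffix


# Example: three arguments, the two pointers divisible by 16, the last
# argument specialized to 1.
sig = {0: "*fp32", 1: "*fp32", 2: "i32"}
spec = InstanceDescriptor(divisible_by_16={0, 1}, equal_to_1={2})
print("add_kernel_" + kernel_suffix(sig.values(), spec))  # -> add_kernel_0d1d2c
```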
+ assert ptx + for line in ptx.split('\n'): + line = line.strip() + if line.startswith('// .globl'): + return line.split()[-1] -class DependenciesFinder(ast.NodeVisitor): +def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]: + valid_outputs = ("ttir", "ttgir", "ptx", "cubin") + assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) - def __init__(self, globals, src) -> None: - super().__init__() - self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() - self.globals = globals + # triton-ir + module, _ = make_triton_ir(fn, signature, specialization, constants) + if output == "ttir": + return module - def visit_Name(self, node): - return self.globals.get(node.id, None) + assert output == "cubin" + assert torch.version.hip is None + backend = _triton.runtime.backend.CUDA + if extern_libs is None: + extern_libs = dict() + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs) + return asm, shared_mem, name - def visit_Attribute(self, node): - lhs = self.visit(node.value) - while isinstance(lhs, ast.Attribute): - lhs = self.visit(lhs.value) - if lhs is None or lhs is triton: - return None - return getattr(lhs, node.attr) - def visit_Call(self, node): - func = self.visit(node.func) - if func is None: - return - if inspect.isbuiltin(func): - return - if func.__module__ and func.__module__.startswith('triton.'): - return - assert isinstance(func, triton.JITFunction) - if func.hash is None: - tree = ast.parse(func.src) - finder = DependenciesFinder(func.__globals__, func.src) - finder.visit(tree) - func.hash = finder.ret - self.ret = (self.ret + func.hash).encode("utf-8") - self.ret = hashlib.md5(self.ret).hexdigest() +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + }[ty] + + +def generate_name_initializer(signature): + src = "int i = 0;\n" + tys = signature.split(',') + for i, ty in enumerate(tys): + src + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +@functools.lru_cache() +def libcuda_dir(): + loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1] + return os.path.dirname(loc) + + +def _build(name, src, path): + # add framework + extra_compile_args = [] + library_dirs = [libcuda_dir()] + include_dirs = [path, "/usr/local/cuda/include/"] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c++', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + path) + args.append('--build-lib=' + path) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + # with quiet(): + setuptools.setup(**args) + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(path, 
'{name}{suffix}'.format(name=name, suffix=suffix)) + return so + + +def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, tmpdir): + headers = dict() + + # write all cubins to header files + assert len(binaries) == 1, "AoT compilation not yet supported" + + for bin, shmem_size, name in binaries: + assert len(name) < 1024 + initializer = f""" +const char* {name}_ptx = R"({bin["ptx"]})"; +unsigned char {name}_bin[] = {{ {','.join(map(hex, bin["cubin"]))} }}; +unsigned int {name}_shmem = {shmem_size};""" + headers[name] = os.path.join(tmpdir, f"{name}.h") + with open(headers[name], "w") as f: + f.write(initializer) + + func_init = '\n '.join(f"init_function(\"{name}\", {name}_bin, {name}_shmem, device);" for _, _, name in binaries) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + src = "" + for bin, shmem_size, name in binaries: + src += f"#include \"{name}.h\"\n" + src += f""" +#include \"cuda.h\" +#include + +inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static CUmodule module = 0; +static CUfunction function = 0; + +static void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{ + CUmodule mod; + CUfunction fun; + CUDA_CHECK(cuModuleLoadData(&mod, src)); + CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); + if (n_shared_bytes > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + int n_spills, n_reg; + CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device)); + CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + CUDA_CHECK(cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); + }} + module = mod; + function = fun; +}} + +static void init_module(CUdevice device) {{ + {func_init} +}} + + +void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{ + CUcontext ctx; + CUdevice device; + CUDA_CHECK(cuStreamGetCtx(stream, &ctx)); + CUDA_CHECK(cuCtxGetDevice(&device)); + + // TODO: machine may have heterogeneous devices + if(function == 0){{ + init_module(device); + }} + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() 
if i not in constants)} }}; + if(gridX*gridY*gridZ > 0){{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*{num_warps}, 1, 1, {name}_shmem, stream, params, 0)); + }} +}} + +CUdeviceptr getPointer(PyObject *obj, int idx) {{ + if (PyLong_Check(obj)) {{ + return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj); + }} + if (obj == Py_None) {{ + return (CUdeviceptr)0; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + }} + return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret); + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return (CUdeviceptr)0; +}} + + +static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t stream; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &stream, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + return NULL; + }} + + _{kernel_name}(gridX, gridY, gridZ, (CUstream)stream, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"{kernel_name}", {kernel_name}, METH_VARARGS, "Call {kernel_name} kernel"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"{kernel_name}\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit_{kernel_name}() {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + PyObject *ptx = PyDict_New(); +""" + + for _, _, name in binaries: + src += f""" + PyObject *py_{name}_ptx = PyUnicode_FromString({name}_ptx); + PyDict_SetItemString(ptx, "{name}", py_{name}_ptx); + Py_DECREF(py_{name}_ptx); +""" + + src += """ + PyModule_AddObject(m, "ptx", ptx); + return m; +} +""" + + return src def default_cache_dir(): return os.path.join(os.environ["HOME"], ".triton", "cache") -class JITFunction: +class CacheManager: - cache_hook = None + def __init__(self, key): + self.key = key + self.bin_path = None + self.lock_path = None + # if caching is enabled, get the lock and bin path + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + if self.cache_dir: + os.makedirs(self.cache_dir, exist_ok=True) + if self.cache_dir: + self.bin_path = os.path.join(self.cache_dir, self.key + ".so") + self.lock_path = self.bin_path + ".lock" - def __init__(self, fn, version=None, inline=True, do_not_specialize=None): - # information of wrapped function - self.fn = fn - self.module = fn.__module__ - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - self.arg_defaults = [v.default for v in signature.parameters.values()] + def has_file(self): + return self.bin_path and os.path.exists(self.bin_path) - self.version = version - self.inline = inline - self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def"):] - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - 
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] - # cache for callable driver objects (e.g. CUkernel) - self.bin_cache = dict() - self.hash = None - # JITFunction can be instantiated as kernel - # when called with a grid using __getitem__ - self.kernel_decorators = [] - self.kernel = None - # annotations - self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} - self.__annotations__ = fn.__annotations__ - # constexprs - self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] - # forward docs - self.__doc__ = fn.__doc__ - self.__name__ = fn.__name__ - self.__globals__ = fn.__globals__ - self.__module__ = fn.__module__ + def put(self, binary): + if self.bin_path: + assert self.lock_path is not None + with FileLock(self.lock_path): + with open(self.bin_path + ".tmp", "wb") as f: + f.write(binary) + os.rename(self.bin_path + ".tmp", self.bin_path) - @property - @functools.lru_cache() - def cache_key(self): - # TODO : hash should be attribute of `self` - if self.hash is None: - dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) - dependencies_finder.visit(self.parse()) - self.hash = dependencies_finder.ret + version_key() - return self.hash - # we do not parse `src` in the constructor because - # the user might want to monkey-patch self.src dynamically. - # Some unit tests do this, for example. - def parse(self): - tree = ast.parse(self.src) - assert isinstance(tree, ast.Module) - assert len(tree.body) == 1 - assert isinstance(tree.body[0], ast.FunctionDef) - return tree +def make_cache_key(fn, signature, configs, constants, num_warps, num_stages): + # Get unique key for the compiled code + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) + configs_key = [get_conf_key(conf) for conf in configs] + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key - def __call__(self, *args, **kwargs): - raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.") - # - when `.src` attribute is set, cache path needs - # to be reinitialized - # - when kernel decorators change, cached kernel - # needs to be cleared - def __setattr__(self, name, value): - if name == 'kernel_decorators': - self.kernel = None - super(JITFunction, self).__setattr__(name, value) - if name == 'src': - self.hash = None - JITFunction.cache_key.fget.cache_clear() +def make_shared_object(fn, constants, signature, num_warps, binaries, tmpdir): + src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + with quiet(): + bin_path = _build(fn.__name__, src_path, tmpdir) + with open(bin_path, "rb") as f: + return f.read() - def _init_kernel(self): - if self.kernel is None: - self.kernel = Kernel(self) - for decorator in reversed(self.kernel_decorators): - self.kernel = decorator(self.kernel) - return self.kernel - def warmup(self, compile): - return self._warmup(**compile, is_manual_warmup=True) +def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): + # we get the kernel, i.e. 
the first function generated in the module + if configs is None: + assert False, "automatic specialization is not supported yet" + ref, _ = make_triton_ir(fn, signature, _triton.code_gen.instance_descriptor(), constants) + fns = ref.get_functions() + configs = _triton.infer_specialization_configs(fns[0]) + assert len(configs) == 1 + # cache manager + cache_key = make_cache_key(fn, signature, configs, constants, num_warps, num_stages) + cache_manager = CacheManager(cache_key) + # retrieve cached shared object if it exists + if cache_manager.has_file(): + return CompiledKernel(fn.__name__, cache_manager.bin_path) + # compile all the configs + binaries = [] + for config in configs: + binaries.append(_compile(fn, signature, device, constants, config, num_warps, num_stages, extern_libs, "cubin")) + # generate and compile glue code into shared object + with tempfile.TemporaryDirectory() as tmpdir: + all_constants = set(constants.keys()) + all_constants.update(configs[0].equal_to_1) + so = make_shared_object(fn, all_constants, signature, num_warps, binaries, tmpdir) - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup): - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + # write shared object to cache + cache_manager.put(so) + return CompiledKernel(fn.__name__, cache_manager.bin_path) - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) - if cache_dir: - os.makedirs(cache_dir, exist_ok=True) - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None +class CompiledKernel: - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs) - if JITFunction.cache_hook is not None: - name = self.__name__ - info = key.split('-')[-3:] - num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] - # make signature human-readable - arg_reprs = [] - for arg_name, arg_sig in zip(self.arg_names, sig): - arg_reprs.append(f'{arg_name}: {arg_sig}') - # assemble the repr - arg_reprs = ", ".join(arg_reprs) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" - noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None) - if noop: - return True - - if binary is None: - binary = self._compile(**compile) - - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - - self.bin_cache[key] = LoadedBinary(device, binary) - return False - - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs): - # create IR module - context = _triton.ir.context() - # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) - # generate Triton-IR - # export symbols visible from self into 
code-generator object - gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) - try: - generator.visit(self.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e - # Compile to machine code - if torch.version.hip is None: - backend = _triton.runtime.backend.CUDA - else: - backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs) - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - if shared_mem > max_shared_memory: - raise OutOfResources(shared_mem, max_shared_memory, "shared memory") - return Binary(backend, name, asm, shared_mem, num_warps) + def __init__(self, fn_name, data_path): + import importlib.util + spec = importlib.util.spec_from_file_location(fn_name, data_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.c_wrapper = getattr(mod, fn_name) + ptx = getattr(mod, "ptx") + if len(ptx) == 1: + self.asm = {"ptx": list(ptx.values())[0]} def __getitem__(self, grid): - return Launcher(self._init_kernel(), grid) - - def __repr__(self): - return f"JITFunction({self.module}:{self.fn.__name__})" - - -class Config: - """ - An object that represents a possible kernel configuration for the auto-tuner to try. - - :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. - :type meta: dict[Str, Any] - :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if - `num_warps=8`, then each kernel instance will be automatically parallelized to - cooperatively execute using `8 * 32 = 256` threads. - :type num_warps: int - :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. - Mostly useful for matrix multiplication workloads on SM80+ GPUs. - :type num_stages: int - :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this - function are args. - """ - - def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): - self.kwargs = kwargs - self.num_warps = num_warps - self.num_stages = num_stages - self.pre_hook = pre_hook - - def __str__(self): - res = [] - for k, v in self.kwargs.items(): - res.append(f'{k}: {v}') - res.append(f'num_warps: {self.num_warps}') - res.append(f'num_stages: {self.num_stages}') - return ', '.join(res) - - -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - - .. highlight:: python - .. code-block:: python - - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. 
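To make the new entry point easier to review, here is a sketch of how the `compile()` function and `CompiledKernel` class defined above might be driven end to end. The `signature` format (a dict of argument index to type string, with a leading `*` marking pointers), the positional argument order of `instance_descriptor`, and the example kernel itself are inferred from this diff rather than documented, and running the sketch requires a CUDA device and the built extension.

```python
import torch
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


# One specialization: every pointer argument divisible by 16, no argument
# specialized to 1. The constructor argument order (divisibleBy16, equalTo1)
# is assumed from the struct layout bound in triton.cc.
spec = _triton.code_gen.instance_descriptor({0, 1, 2}, set())

binary = triton.compile(
    add_kernel,
    signature={0: "*fp32", 1: "*fp32", 2: "*fp32", 3: "i32"},
    device=torch.cuda.current_device(),
    constants={"BLOCK": 1024},   # constexpr arguments are resolved by name via cst_key
    configs=[spec],
)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
# CompiledKernel.__getitem__ returns a runner taking the non-constexpr
# arguments plus an optional CUDA stream.
binary[(x.numel() // 1024, 1, 1)](x, y, out, x.numel())
```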
- - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - def decorator(fn): - def wrapper(kernel): - return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def heuristics(values): - """ - Decorator for specifying how the values of certain meta-parameters may be computed. - This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. - - .. highlight:: python - .. code-block:: python - - @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - - - .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. - each such function takes a list of positional arguments as input. - .type values: dict[str, Callable[[list[Any]], Any]] - """ - def decorator(fn): - def wrapper(kernel): - def fun(*args, **meta): - for v, heur in values.items(): - assert v not in meta - meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) - return kernel(*args, **meta) - return fun - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def jit(*args, **kwargs): - """ - Decorator for JIT-compiling a function using the Triton compiler. - - :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. - - :note: This function will be compiled and run on the GPU. 
It will only have access to: - - * python primitives, - * objects within the triton.language package, - * arguments to this function, - * other jit'd functions - - :param fn: the function to be jit-compiled - :type fn: Callable - """ - if args: - assert len(args) == 1 - assert callable(args[0]) - return JITFunction(args[0], **kwargs) - else: - def decorator(fn): - return JITFunction(fn, **kwargs) - return decorator - -###### - -# class ForwardDeclaration: - -# def __init__(self, name, ret_ty, arg_tys) -> None: -# self.name = name -# self.ret_ty = ret_ty -# self.arg_tys = arg_tys - -# def forward_declare(name, ret_ty, arg_tys): -# return ForwardDeclaration(name, ret_ty, arg_tys) - -###### - - -def cdiv(x, y): - return (x + y - 1) // y - - -def next_power_of_2(n): - """Return the smallest power of 2 greater than or equal to n""" - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - -###### - - -class TensorWrapper: - def __init__(self, base, dtype): - self.dtype = dtype - self.base = base - self.is_cuda = base.is_cuda - self.device = base.device - - def data_ptr(self): - return self.base.data_ptr() - - def __str__(self) -> str: - return f'TensorWrapper[{self.dtype}]({self.base})' - - -def reinterpret(tensor, dtype): - if isinstance(tensor, TensorWrapper): - if dtype == tensor.base.dtype: - # Reinterpreting to the original interpretation; return the base. - return tensor.base - else: - # Reinterpreting a wrapped tensor to a different type. - return TensorWrapper(tensor.base, dtype) - elif isinstance(tensor, torch.Tensor): - # A new wrapper is needed around an unwrapped tensor. - return TensorWrapper(tensor, dtype) - else: - raise TypeError(f'Cannot reinterpret a {type(tensor)}.') + def runner(*args, stream=None): + if stream is None: + stream = torch.cuda.current_stream().cuda_stream + self.c_wrapper(grid[0], grid[1], grid[2], stream, *args) + return runner diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e52a488b2..63a9ab7f2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -420,7 +420,7 @@ class tensor: self.numel = 1 for s in self.shape: self.numel *= s - is_pow2 = (self.numel and (not(self.numel & (self.numel - 1)))) + is_pow2 = (self.numel and (not (self.numel & (self.numel - 1)))) if not is_pow2: raise ValueError("Triton tensors must have a power-of-two number of elements") self.numel = constexpr(self.numel) diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index bb915be13..33223b72d 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -18,8 +18,8 @@ def num_warps(n): @triton.jit def _blocksparse_softmax_fwd( - Out, A, stride_xz, LUT, - R, extent, stride_zr, stride_hr, # relative attention + Out, A, LUT, R, stride_xz, + extent, stride_zr, stride_hr, # relative attention scale, is_causal, ROW_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function): # enqueue kernel out = torch.empty_like(a) _blocksparse_softmax_fwd[grid]( - out, a, a.stride(0), lut, - rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn + out, a, lut, rel_logits, a.stride(0), + rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn scale, is_causal, BLOCK_SIZE=block, diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index f1ac78849..0ffcc1677 100644 --- a/python/triton/ops/matmul.py +++ 
b/python/triton/ops/matmul.py
@@ -26,9 +26,6 @@ def get_configs_io_bound():
     return configs
 
 
-@triton.heuristics({
-    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
 @triton.autotune(
     configs=[
         # basic configs for compute-bound matmuls
@@ -59,6 +56,9 @@ def get_configs_io_bound():
         'top_k': 10
     },
 )
+@triton.heuristics({
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
 @triton.jit
 def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py
new file mode 100644
index 000000000..d9946c27c
--- /dev/null
+++ b/python/triton/runtime/__init__.py
@@ -0,0 +1,2 @@
+from .autotuner import Config, Heuristics, autotune, heuristics  # noqa: F401
+from .jit import JITFunction, KernelInterface, version_key  # noqa: F401
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
new file mode 100644
index 000000000..2175501b6
--- /dev/null
+++ b/python/triton/runtime/autotuner.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import builtins
+import time
+from typing import Dict
+
+from ..testing import do_bench
+from .jit import KernelInterface
+
+
+class Autotuner(KernelInterface):
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
+        '''
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predict running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+        '''
+        if not configs:
+            self.configs = [Config(dict(), num_warps=4, num_stages=2)]
+        else:
+            self.configs = configs
+        self.key_idx = [arg_names.index(k) for k in key]
+        self.cache = dict()
+        # hook to reset all required tensors to zero before relaunching a kernel
+        self.hook = lambda args: 0
+        if reset_to_zero is not None:
+            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+            def _hook(args):
+                for i in self.reset_idx:
+                    args[i].zero_()
+            self.hook = _hook
+        self.arg_names = arg_names
+        # prune configs
+        if prune_configs_by:
+            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+            # .get avoids a NameError when no early pruning function is provided
+            early_config_prune = prune_configs_by.get('early_config_prune', None)
+        else:
+            perf_model, top_k, early_config_prune = None, None, None
+        self.perf_model, self.configs_top_k = perf_model, top_k
+        self.early_config_prune = early_config_prune
+        self.fn = fn
+
+    def _bench(self, *args, config, **meta):
+        # check for conflicts, i.e. meta-parameters both provided
+        # as kwargs and by the autotuner
+        conflicts = meta.keys() & config.kwargs.keys()
+        if conflicts:
+            raise ValueError(
+                f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                " Make sure that you don't re-define auto-tuned symbols."
+            )
+        # augment meta-parameters with tunable ones
+        current = dict(meta, **config.kwargs)
+
+        def kernel_call():
+            if config.pre_hook:
+                config.pre_hook(self.nargs)
+            self.hook(args)
+            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+        return do_bench(kernel_call)
+
+    def run(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        if len(self.configs) > 1:
+            key = tuple([args[i] for i in self.key_idx])
+            if key not in self.cache:
+                # prune configs
+                pruned_configs = self.configs
+                if self.early_config_prune:
+                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                bench_start = time.time()
+                timings = {config: self._bench(*args, config=config, **kwargs)
+                           for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
+                self.cache[key] = builtins.min(timings, key=timings.get)
+                self.hook(args)
+                self.configs_timings = timings
+            config = self.cache[key]
+        else:
+            config = self.configs[0]
+        self.best_config = config
+        if config.pre_hook is not None:
+            config.pre_hook(self.nargs)
+        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+
+
+class Config:
+    """
+    An object that represents a possible kernel configuration for the auto-tuner to try.
+
+    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
+    :type kwargs: dict[str, Any]
+    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
+                     `num_warps=8`, then each kernel instance will be automatically parallelized to
+                     cooperatively execute using `8 * 32 = 256` threads.
+    :type num_warps: int
+    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
+                      Mostly useful for matrix multiplication workloads on SM80+ GPUs.
+    :type num_stages: int
+    :ivar pre_hook: a function that will be called before the kernel is launched. It receives the full
+                    dict of kernel arguments as its only parameter.
+    """
+
+    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
+        self.kwargs = kwargs
+        self.num_warps = num_warps
+        self.num_stages = num_stages
+        self.pre_hook = pre_hook
+
+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f'{k}: {v}')
+        res.append(f'num_warps: {self.num_warps}')
+        res.append(f'num_stages: {self.num_stages}')
+        return ', '.join(res)
+
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
+    """
+    Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.autotune(configs=[
+            triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
+            triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
+          ],
+          key=['x_size']  # the two above configs will be evaluated anytime
+                          # the value of x_size changes
+        )
+        @triton.jit
+        def kernel(x_ptr, x_size, **META):
+            BLOCK_SIZE = META['BLOCK_SIZE']
+
+    :note: When all the configurations are evaluated, the kernel will run multiple times.
+           This means that whatever value the kernel updates will be updated multiple times.
+           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+           resets the value of the provided tensors to zero before running any configuration.
+
+    :param configs: a list of :code:`triton.Config` objects
+    :type configs: list[triton.Config]
+    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+    :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predict running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+    :type reset_to_zero: list[str]
+    """
+    def decorator(fn):
+        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
+
+    return decorator
+
+
+class Heuristics(KernelInterface):
+
+    def __init__(self, fn, arg_names, values) -> None:
+        self.fn = fn
+        self.values = values
+        self.arg_names = arg_names
+
+    def run(self, *args, **kwargs):
+        for v, heur in self.values.items():
+            kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
+        return self.fn.run(*args, **kwargs)
+
+
+def heuristics(values):
+    """
+    Decorator for specifying how the values of certain meta-parameters may be computed.
+    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))})
+        @triton.jit
+        def kernel(x_ptr, x_size, **META):
+            BLOCK_SIZE = META['BLOCK_SIZE']  # smallest power-of-two >= x_size
+
+    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
+                   each such function takes a dict mapping argument names to argument values as input.
+    :type values: dict[str, Callable[[dict[str, Any]], Any]]
+    """
+    def decorator(fn):
+        return Heuristics(fn, fn.arg_names, values)
+
+    return decorator
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
new file mode 100644
index 000000000..025f268ac
--- /dev/null
+++ b/python/triton/runtime/jit.py
@@ -0,0 +1,415 @@
+from __future__ import annotations, division
+
+import ast
+import functools
+import hashlib
+import inspect
+import os
+import subprocess
+import textwrap
+from collections import namedtuple
+
+import torch
+
+import triton
+
+try:
+    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+except ImportError:
+    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+
+# -----------------------------------------------------------------------------
+# Dependencies Finder
+# -----------------------------------------------------------------------------
+
+
+class DependenciesFinder(ast.NodeVisitor):
+    """
+    This AST visitor is used to find dependencies of a JITFunction. This can
+    be used to invalidate a JITFunction's hash when its source code -- or
+    that of its dependencies -- changes.
+ """ + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if inspect.isbuiltin(func): + return + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, JITFunction) + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) + finder.visit(tree) + func.hash = finder.ret + self.ret = (self.ret + func.hash).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +@functools.lru_cache() +def version_key(): + import pkgutil + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + with open(triton.compiler.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + + +class KernelInterface: + + def __getitem__(self, grid): + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. 
+ """ + def launcher(*args, **kwargs): + return self.run(*args, grid=grid, **kwargs) + return launcher + + +class JITFunction(KernelInterface): + + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -2**31 <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**31 <= arg and arg <= 2**32 - 1: + return "u32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return 'fp32' + elif arg is None: + return None + else: + raise TypeError(f'Unsupported type {type(arg)} for {arg}') + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return (arg.data_ptr() % JITFunction.divisibility == 0) + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) + + @staticmethod + def _type_of(key): + if isinstance(key, (torch.dtype, triton.language.dtype)): + ty = { + torch.bool: 'i1', + torch.float16: 'fp16', + torch.bfloat16: 'bf16', + torch.float32: 'fp32', + torch.float64: 'fp64', + torch.uint8: 'u8', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + + triton.language.uint8: 'u8', + triton.language.uint16: 'u16', + triton.language.uint32: 'u32', + triton.language.uint64: 'u64', + triton.language.float8: 'fp8', + }[key] + return f'*{ty}' + if key is None: + return '*i8' + assert isinstance(key, str) + return key + + def _make_signature(self, sig_key): + signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) + return signature + + def _make_constants(self, constexpr_key): + constants = {i: k for i, k in zip(self.constexprs, constexpr_key)} + return constants + + def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + if JITFunction.cache_hook is None: + return False + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + key = str(key) + + class LegacyCompiler: + def __init__(self, module, name): + self.module = module + self.name = name + pass + + kwargs = dict(signature=signature, device=device, constants=constants, + num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, + configs=configs) + + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + + def _make_launcher(self): + regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] + constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] + args = ', 
'.join(regular_args) + # cache key for regular argument type + sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args]) + # cache key for constexpr argument values + constexpr_keys = ', '.join(constexpr_args) + # cache key for argument specialization + specializations = [] + for i, arg in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") ' + f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) ' + f'else (False,)'] + spec_keys = ', '.join(specializations) + grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + + src = f""" +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): + sig_key = {sig_keys}, + constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} + spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} + key = (version_key, sig_key, constexpr_key, spec_key) + if not extern_libs is None: + key = (key, tuple(extern_libs.items())) + assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + if callable(grid): + grid = grid({{{grid_args}}}) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + device = torch.cuda.current_device() + torch.cuda.set_device(device) + if stream is None: + stream = get_cuda_stream(device) + try: + bin = cache[key] + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + return bin + # kernel not cached -- compile + except KeyError: + # build dict of constant values + args = [{args}] + configs = self._get_config(*args), + constants = self._make_constants(constexpr_key) + constants.update({{i: None for i, arg in enumerate(args) if arg is None}}) + constants.update({{i: 1 for i in configs[0].equal_to_1}}) + # build kernel signature -- doesn't include specialized arguments + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + device = 0 + if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + self.cache[key] = bin + return bin + return None +""" + scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream, + "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of, + "cache": self.cache, "triton": triton, "torch": torch} + exec(src, scope) + return scope[self.fn.__name__] + + def __init__(self, fn, version=None, do_not_specialize=None): + self.fn = fn + self.module = fn.__module__ + self.version = version + # function signature information + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()]) + # specialization hints + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = set([self.arg_names.index(arg) if 
isinstance(arg, str) else arg for arg in self.do_not_specialize]) + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def"):] + # cache of just-in-time compiled kernels + self.cache = dict() + self.hash = None + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel_decorators = [] + self.kernel = None + # annotations + self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} + self.__annotations__ = fn.__annotations__ + # index of constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] + # launcher + self.run = self._make_launcher() + # re-use docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + @functools.lru_cache() + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + version_key() + return self.hash + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + # - when kernel decorators change, cached kernel + # needs to be cleared + if name == 'kernel_decorators': + self.kernel = None + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == 'src': + self.hash = None + JITFunction.cache_key.fget.cache_clear() + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +def jit(*args, **kwargs): + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. + + :note: This function will be compiled and run on the GPU. 
It will only have access to: + + * python primitives, + * objects within the triton.language package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + if args: + assert len(args) == 1 + assert callable(args[0]) + return JITFunction(args[0], **kwargs) + else: + def decorator(fn): + return JITFunction(fn, **kwargs) + return decorator + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/testing.py b/python/triton/testing.py index 594edcbf2..2c9ece2fe 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -7,7 +7,7 @@ from contextlib import contextmanager import torch import triton._C.libtriton.triton as _triton -from .code_gen import OutOfResources +from .compiler import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass diff --git a/python/triton/utils.py b/python/triton/utils.py new file mode 100644 index 000000000..f446dd06a --- /dev/null +++ b/python/triton/utils.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import torch + + +def cdiv(x, y): + return (x + y - 1) // y + + +def next_power_of_2(n): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
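The ops/matmul.py hunk above swaps @triton.autotune and @triton.heuristics so that the autotuner is the outermost wrapper under the new runtime. The following is a minimal sketch of how the two decorators now compose; the element-wise kernel and all names in it are illustrative and not part of the patch.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],  # re-benchmark whenever n_elements changes
)
@triton.heuristics({
    # computed from the launch args *and* the config picked by the autotuner,
    # mirroring how EVEN_K is derived from BLOCK_K/SPLIT_K in ops/matmul.py
    'EVEN_BLOCK': lambda args: args['n_elements'] % args['BLOCK_SIZE'] == 0,
})
@triton.jit
def _scale(x_ptr, y_ptr, n_elements, scale,
           EVEN_BLOCK: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if EVEN_BLOCK:
        x = tl.load(x_ptr + offsets)
        tl.store(y_ptr + offsets, x * scale)
    else:
        x = tl.load(x_ptr + offsets, mask=offsets < n_elements)
        tl.store(y_ptr + offsets, x * scale, mask=offsets < n_elements)


x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
_scale[grid](x, y, x.numel(), 2.0)

With this ordering, Heuristics.run receives the meta-parameters chosen by the Autotuner in its kwargs, so a heuristic may depend on an auto-tuned value such as BLOCK_SIZE.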
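The Autotuner's prune_configs_by path (early_config_prune first, then perf_model/top_k) can be exercised as sketched below; the performance model, pruning function and kernel are illustrative placeholders, not part of the patch.

import triton
import triton.language as tl


def _perf_model(n_elements, BLOCK_SIZE, num_stages, num_warps, **kwargs):
    # Toy estimate (lower is "faster"). The autotuner calls this with the
    # launch arguments plus each config's kwargs/num_stages/num_warps and
    # keeps only the 'top_k' best-ranked configs for real benchmarking.
    return abs(n_elements / BLOCK_SIZE - 128)


def _early_prune(configs, nargs):
    # Called once per new key with (configs, dict of launch arguments);
    # here we drop configs whose block is larger than the problem.
    kept = [c for c in configs if c.kwargs['BLOCK_SIZE'] <= nargs['n_elements']]
    return kept or configs


@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE': 2 ** i}, num_warps=4) for i in range(7, 12)],
    key=['n_elements'],
    prune_configs_by={'perf_model': _perf_model, 'top_k': 2, 'early_config_prune': _early_prune},
)
@triton.jit
def _copy(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)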
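DependenciesFinder and version_key() together decide when a cached kernel can be reused: the cache key folds the md5 of the kernel source with the hashes of every jit'd callee and with the installed Triton/ptxas versions. A small sketch, assuming a working Triton install (kernel name is illustrative):

import triton
import triton.language as tl


@triton.jit
def _plus_one(X, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    tl.store(X + offsets, tl.load(X + offsets) + 1)


# Editing the kernel body, editing any jit'd callee, or upgrading
# Triton/ptxas changes this string, so stale binaries are not reused.
print(_plus_one.cache_key)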
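JITFunction._call_hook invokes JITFunction.cache_hook with the signature, constants and specializations that will be handed to triton.compile; a truthy return value skips compilation. A sketch of a logging hook under assumed names:

from triton.runtime.jit import JITFunction

compiled = []


def _log_compile(key, repr, fn, compile, is_manual_warmup, already_compiled, **extra):
    # 'compile' carries the kwargs forwarded to triton.compile
    # (signature, constants, configs, num_warps, num_stages, ...).
    compiled.append((repr, compile["signature"], compile["configs"]))
    return False  # falsy: let compilation proceed; truthy would skip it


JITFunction.cache_hook = _log_compile
# ... launching any @triton.jit kernel here records its first compilation ...
JITFunction.cache_hook = None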
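To make the caching and specialization scheme concrete: _key_of maps each launch argument to a dtype key, and _spec_of maps it to a specialization tuple (16-byte divisibility, equality to 1). A few direct calls to the static helpers, with expected results as comments:

import torch
from triton.runtime.jit import JITFunction

x = torch.empty(1024, dtype=torch.float16)

JITFunction._key_of(x)        # torch.float16 -> rendered as '*fp16' by _type_of
JITFunction._key_of(True)     # 'i1'
JITFunction._key_of(3)        # 'i32' (fits in 32 bits)
JITFunction._key_of(2 ** 40)  # 'i64'

JITFunction._spec_of(x)       # (True,) when x.data_ptr() is 16-byte aligned
JITFunction._spec_of(128)     # (True, False): divisible by 16, not equal to 1
JITFunction._spec_of(1)       # (False, True): arguments equal to 1 are folded into constexpr 1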
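Finally, reinterpret/TensorWrapper (defined both in runtime/jit.py and utils.py above) give a zero-copy way to view a torch tensor under a Triton dtype that torch itself lacks, such as unsigned integers:

import torch
import triton.language as tl
from triton.runtime.jit import TensorWrapper, reinterpret

x = torch.arange(16, dtype=torch.int8)
ux = reinterpret(x, tl.uint8)            # wrap: same storage, viewed as uint8
assert isinstance(ux, TensorWrapper)
assert ux.data_ptr() == x.data_ptr()     # no copy is made
assert reinterpret(ux, torch.int8) is x  # back to the base dtype returns the original tensor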