diff --git a/python/examples/empty.py b/python/examples/empty.py index b17d58ca3..df313fb85 100644 --- a/python/examples/empty.py +++ b/python/examples/empty.py @@ -1,3 +1,5 @@ +import torch + import triton import triton.language as tl @@ -7,4 +9,5 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): pass -ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir") +X = torch.randn(1, device="cuda") +pgm = kernel[(1,)](X, 1, 1, BLOCK=1024) diff --git a/python/src/triton.cc b/python/src/triton.cc index 6810d72fd..52dffd1ae 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) { /*****************************************************************************/ typedef std::map asm_map_t; -// --------------------------------------- -// Load provided assembly code into driver -// --------------------------------------- - -// CUDA -std::tuple cu_load_binary(const std::string &name, - asm_map_t &asm_map, - size_t n_shared_bytes, - uint64_t dev) { - // load assembly - std::string assembly; - if (asm_map.find("cubin") != asm_map.end()) - assembly = py::cast(asm_map["cubin"]); - else - assembly = py::cast(asm_map["ptx"]); - // create driver handles - CUfunction fun; - CUmodule mod; - drv::dispatch::cuModuleLoadData(&mod, assembly.c_str()); - drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str()); - // set dynamic shared memory if necessary - int shared_optin; - drv::dispatch::cuDeviceGetAttribute( - &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - dev); - if (n_shared_bytes > 49152 && shared_optin > 49152) { - drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED); - int shared_total, shared_static; - int n_spills, n_reg; - drv::dispatch::cuDeviceGetAttribute( - &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, - dev); - drv::dispatch::cuFuncGetAttribute(&shared_static, - CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun); - drv::dispatch::cuFuncGetAttribute(&n_spills, - CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun); - drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun); - drv::dispatch::cuFuncSetAttribute( - fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_optin - shared_static); - } - return std::make_tuple((uint64_t)mod, (uint64_t)fun); -} - /*****************************************************************************/ /* Python bindings for triton::ir */ /*****************************************************************************/ @@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) { m.def( "load_binary", - [](backend_t backend, const std::string &name, asm_map_t &asm_map, - size_t n_shared_bytes, uint64_t dev) { + [](const std::string &name, const std::string &data, + size_t n_shared_bytes, uint64_t device) { py::gil_scoped_release allow_threads; - assert(backend == CUDA); // Only CUDA is supported now. - return cu_load_binary(name, asm_map, n_shared_bytes, dev); + // create driver handles + CUfunction fun; + CUmodule mod; + drv::dispatch::cuModuleLoadData(&mod, data.c_str()); + drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str()); + // get allocated registers and spilled registers from the function + int n_regs = 0; + int n_spills = 0; + drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, + fun); + drv::dispatch::cuFuncGetAttribute( + &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + drv::dispatch::cuDeviceGetAttribute( + &shared_optin, + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device); + if (n_shared_bytes > 49152 && shared_optin > 49152) { + drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED); + int shared_total, shared_static; + drv::dispatch::cuDeviceGetAttribute( + &shared_total, + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device); + drv::dispatch::cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun); + drv::dispatch::cuFuncSetAttribute( + fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static); + } + return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, + (uint64_t)n_spills); }, py::return_value_policy::take_ownership); } diff --git a/python/tests/test_cast.py b/python/tests/test_cast.py index cc7793aa0..9b5513aa3 100644 --- a/python/tests/test_cast.py +++ b/python/tests/test_cast.py @@ -2,8 +2,9 @@ import triton import triton.language as tl +# TODO: function with no arguments don't work @triton.jit -def cast_check(): +def cast_check(X): zero_0d = tl.zeros([], dtype=tl.float32) zero_1d = tl.zeros([2], dtype=tl.float32) zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32) @@ -48,9 +49,9 @@ def cast_check(): def test_cast_check(): - kernel = triton.compile(cast_check, - signature="", - device=0, - output="ttir") + kernel = triton.compiler._compile(cast_check, + signature="*fp32", + device=0, + output="ttgir") assert (kernel) # TODO: Check types of the results diff --git a/python/tests/test_compiler.py b/python/tests/test_compiler.py index ff4ca1bcf..3c34eae93 100644 --- a/python/tests/test_compiler.py +++ b/python/tests/test_compiler.py @@ -2,7 +2,6 @@ import torch import triton import triton.language as tl -import triton.runtime as runtime # trigger the torch.device implicitly to ensure cuda context initialization torch.zeros([10], device=torch.device('cuda')) @@ -16,30 +15,18 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr): def test_empty_kernel_cubin_compile(): device = torch.cuda.current_device() - cubin = triton.compile(empty_kernel, - "*fp32,i32,i32", - device=device, - constants={"BLOCK": 256}, - output="cubin") + kernel = triton.compile(empty_kernel, + "*fp32,i32,i32", + device=device, + constants={"BLOCK": 256}) - print('cubin size:', len(cubin)) - assert len(cubin) > 0 + assert len(kernel.asm["cubin"]) > 0 def test_empty_kernel_launch(): - device = torch.cuda.current_device() - binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32", - constants={"BLOCK": 256}, - num_warps=4, - num_stages=3) grid = lambda META: ( triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']), ) A = torch.zeros([1024], device="cuda") - runtime.launch_kernel(kernel=binary, - grid=grid, - device=device, - X=A, - stride_xm=256, - BLOCK=tl.constexpr(256)) + empty_kernel[grid](X=A, stride_xm=256, BLOCK=256) diff --git a/python/tests/test_math_ops.py b/python/tests/test_math_ops.py index f5ed9fdf9..c464818fb 100644 --- a/python/tests/test_math_ops.py +++ b/python/tests/test_math_ops.py @@ -23,11 +23,11 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr): def test_empty_kernel_cubin_compile(): - kernel = triton.compile(math_kernel, - "*fp32,*fp32,*fp32,*fp32,i32", - device=0, - constants={"BLOCK_SIZE": 256}, - output="ttgir") # "cubin" + kernel = triton.compiler._compile(math_kernel, + "*fp32,*fp32,*fp32,*fp32,i32", + device=0, + constants={"BLOCK_SIZE": 256}, + output="ttgir") # "cubin" assert kernel # TODO: Check if the values are correct. # TODO: Cover all the math operators diff --git a/python/tests/test_transpose.py b/python/tests/test_transpose.py index 3daee7bfe..8875b7feb 100644 --- a/python/tests/test_transpose.py +++ b/python/tests/test_transpose.py @@ -4,7 +4,6 @@ from torch.testing import assert_allclose import triton import triton.language as tl -import triton.runtime as runtime @triton.jit @@ -40,29 +39,9 @@ def kernel(x_ptr, stride_xm, [2, 128, 64] ]) def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N): - # TODO: this is to initialize the cuda context since it is not properly - # dealed with in the existing runtime, remove this when the runtime - # is updated - torch.zeros([10], device=torch.device('cuda')) - device = torch.cuda.current_device() - binary = runtime.build_kernel(kernel, - "*fp32,i32,*fp32,i32", - constants={"SIZE_M": SIZE_M, - "SIZE_N": SIZE_N}, - num_warps=NUM_WARPS, - num_stages=3) grid = lambda META: (1, ) - x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32) z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype) - runtime.launch_kernel(kernel=binary, - device=device, - grid=grid, - x_ptr=x, - stride_xm=x.stride(0), - z_ptr=z, - stride_zn=z.stride(0), - SIZE_M=tl.constexpr(SIZE_M), - SIZE_N=tl.constexpr(SIZE_N)) + kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS) golden_z = torch.t(x) assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7) diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py index 6f72171ab..161b11b1f 100644 --- a/python/tests/test_vecadd_no_scf.py +++ b/python/tests/test_vecadd_no_scf.py @@ -3,7 +3,6 @@ from torch.testing import assert_allclose import triton import triton.language as tl -import triton.runtime as runtime def vecadd_no_scf_tester(num_warps, block_size): @@ -22,27 +21,13 @@ def vecadd_no_scf_tester(num_warps, block_size): z_ptrs = z_ptr + offset tl.store(z_ptrs, z) - torch.zeros([10], device=torch.device('cuda')) - device = torch.cuda.current_device() - binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32", - constants={"BLOCK_SIZE_N": block_size}, - num_warps=num_warps, - num_stages=3) - x = torch.randn((block_size,), device='cuda', dtype=torch.float32) y = torch.randn((block_size,), device='cuda', dtype=torch.float32) z = torch.empty((block_size,), device=x.device, dtype=x.dtype) - assert x.shape.numel() % block_size == 0, "Only test load without mask here" grid = lambda EA: (x.shape.numel() // block_size,) + kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps) - runtime.launch_kernel(kernel=binary, - grid=grid, - device=device, - x_ptr=x, - y_ptr=y, - z_ptr=z, - BLOCK_SIZE_N=tl.constexpr(block_size)) golden_z = x + y assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 6963bf25c..c620543ee 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -7,8 +7,9 @@ __version__ = '2.0.0' import torch # submodules from .utils import * -from .runtime import jit, Config, autotune, heuristics -from .compiler import compile +from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface +from .runtime.jit import jit +from .compiler import compile, CompilationError from . import language from . import testing from . import ops diff --git a/python/triton/compiler.py b/python/triton/compiler.py index e18f61ec2..a97252f75 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1,10 +1,26 @@ from __future__ import annotations import ast +import contextlib +import functools +import hashlib +import io +import json +import os +import shutil +import subprocess import sys +import sysconfig +import tempfile import warnings +from collections import namedtuple +from sysconfig import get_paths from typing import Any, Dict, Tuple, Union +import setuptools +import torch +from filelock import FileLock + import triton import triton._C.libtriton.triton as _triton @@ -85,7 +101,7 @@ class enter_sub_region: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()): self.builder = _triton.ir.builder(context) self.module = self.builder.create_module() if module is None else module self.function_ret_types = function_types @@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor): self.lscope = dict() self.attributes = attributes self.constants = constants + self.function_name = function_name self.is_kernel = is_kernel self.last_node = None self.builtins = { @@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor): init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) # initialize function - fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) - fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder)) + fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder)) self.module.push_back(fn) entry = fn.add_entry_block() arg_values = [] @@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(cst, triton.language.constexpr): cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) + continue else: - pass if i in self.attributes: - fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i]) + fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1]) arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) idx += 1 @@ -746,21 +762,40 @@ class OutOfResources(Exception): return (type(self), (self.required, self.limit, self.name)) -def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + +# ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ + + +def make_triton_ir(fn, signature, specialization, constants): context = _triton.ir.context() context.load_triton() # create kernel prototype - constants = {fn.arg_names.index(name): value for name, value in constants.items()} - attributes = {fn.arg_names.index(name): value for name, value in attributes.items()} - if signature.replace(' ', '') != '': - arg_types = signature.replace(' ', '').split(',') - arg_types = [str_to_ty(x) for x in arg_types] - else: - arg_types = [] - prototype = triton.language.function_type([], arg_types) + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} # visit kernel AST gscope = fn.__globals__.copy() - generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True) + function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) + tys = list(signature.values()) + new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1} + new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants] + + prototype = triton.language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True) try: generator.visit(fn.parse()) except Exception as e: @@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): raise e raise CompilationError(fn.src, node) from e ret = generator.module - # module takes ownership of the MLIR context + # module takes ownership of the context ret.context = context - return ret + return ret, generator def optimize_triton_ir(mod): @@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str: return line.split()[-1] -def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]: +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()]) + + +def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]: + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} valid_outputs = ("ttir", "ttgir", "ptx", "cubin") assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) + # triton-ir - module = make_triton_ir(fn, signature, constants, attributes) + module, _ = make_triton_ir(fn, signature, specialization, constants) module = optimize_triton_ir(module) if output == "ttir": return module.str() + # tritongpu-ir module = make_tritongpu_ir(module, num_warps) module = optimize_tritongpu_ir(module, num_stages) @@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d cubin = make_cubin(ptx, device) if output == "cubin": - return cubin, shem_size, kernel_name + return cubin, ptx, shem_size, kernel_name assert False + + +# ------------------------------------------------------------------------------ +# compiler +# ------------------------------------------------------------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + }[ty] + + +def generate_name_initializer(signature): + src = "int i = 0;\n" + tys = signature.split(',') + for i, ty in enumerate(tys): + src + + +def binary_name_to_header_name(name): + if len(name) > 128: + # avoid filename too long errors (filename limit is 255) + name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest() + return f"{name}.h" + + +def generate_launcher(identifier, constants, signature): + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + src = f""" +#include \"cuda.h\" +#include +static inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} +void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + if(gridX*gridY*gridZ > 0){{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} +}} +static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{ + if (PyLong_Check(obj)) {{ + return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj); + }} + if (obj == Py_None) {{ + return (CUdeviceptr)0; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + }} + return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret); + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return (CUdeviceptr)0; +}} +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int num_warps; + int shared_memory; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + return NULL; + }} + _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; +PyMODINIT_FUNC PyInit_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + return src + + +def default_cache_dir(): + return os.path.join(os.environ["HOME"], ".triton", "cache") + + +class CacheManager: + + def __init__(self, key): + self.key = key + self.lock_path = None + # create cache directory if it doesn't exist + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + + def _make_path(self, filename): + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename): + if not self.cache_dir: + return False + return os.path.exists(self._make_path(filename)) + + def put(self, data, filename, binary=True): + if not self.cache_dir: + return + assert self.lock_path is not None + filepath = self._make_path(filename) + with FileLock(self.lock_path): + # use tempfile to be robust against program interruptions + mode = "wb" if binary else "w" + with open(filepath + ".tmp", mode) as f: + f.write(data) + os.rename(filepath + ".tmp", filepath) + + +# utilties for generating and compiling C wrappers + + +@functools.lru_cache() +def libcuda_dir(): + loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1] + return os.path.dirname(loc) + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir): + cuda_lib_dir = libcuda_dir() + cu_include_dir = "/usr/local/cuda/include" + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + py_include_dir = get_paths()["include"] + ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so]) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + library_dirs = [cuda_lib_dir] + include_dirs = [srcdir, cu_include_dir] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so + + +def make_so_cache_key(signature, constants): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{''.join(signature.values())}{constants}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key + + +def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages): + # Get unique key for the compiled code + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) + configs_key = [get_conf_key(conf) for conf in configs] + key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key + + +def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + # we get the kernel, i.e. the first function generated in the module + if configs is None: + configs = [instance_descriptor()] + assert len(configs) == 1 + # cache manager + name = fn.__name__ + # name of files that are cached + so_cache_key = make_so_cache_key(signature, constants) + so_cache_manager = CacheManager(so_cache_key) + so_name = f"{name}.so" + # retrieve stub from cache if it exists + if not so_cache_manager.has_file(so_name): + with tempfile.TemporaryDirectory() as tmpdir: + src = generate_launcher(name, constants, signature) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(fn.__name__, src_path, tmpdir) + with open(so, "rb") as f: + so_cache_manager.put(f.read(), so_name, binary=True) + + # retrieve cached shared object if it exists + fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages) + fn_cache_manager = CacheManager(fn_cache_key) + ptx_name = f"{name}.ptx" + cubin_name = f"{name}.cubin" + data_name = f"{name}.json" + if not fn_cache_manager.has_file(cubin_name) or \ + not fn_cache_manager.has_file(data_name) or \ + not fn_cache_manager.has_file(ptx_name): + cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin") + metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages} + fn_cache_manager.put(cubin, cubin_name) + fn_cache_manager.put(ptx, ptx_name, binary=False) + fn_cache_manager.put(json.dumps(metadata), data_name, binary=False) + + return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir) + + +class CompiledKernel: + + def __init__(self, fn_name, so_path, cache_dir): + + # initialize launcher + import importlib.util + spec = importlib.util.spec_from_file_location("launcher", so_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.c_wrapper = getattr(mod, "launch") + # initialize metadata + with open(os.path.join(cache_dir, f"{fn_name}.json")) as f: + metadata = json.load(f) + self.shared = metadata["shared"] + self.num_warps = metadata["num_warps"] + self.num_stages = metadata["num_stages"] + # initialize asm dict + self.asm = dict() + with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f: + self.asm["cubin"] = f.read() + with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f: + self.asm["ptx"] = f.read() + + device = torch.cuda.current_device() + mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) + self.cu_module = mod + self.cu_function = func + + def __getitem__(self, grid): + def runner(*args, stream=None): + if stream is None: + stream = torch.cuda.current_stream().cuda_stream + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) + return diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index 3be401163..d9946c27c 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,2 +1,2 @@ -from .autotuner import Config, autotune, heuristics # noqa: F401 -from .jit import JITFunction, build_kernel, jit, launch_kernel # noqa: F401 +from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401 +from .jit import JITFunction, KernelInterface, version_key # noqa: F401 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 010af80b4..522964d56 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -5,10 +5,11 @@ import time from typing import Dict from ..testing import do_bench +from .jit import KernelInterface -class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): +class Autotuner(KernelInterface): + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): ''' :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time @@ -21,7 +22,6 @@ class Autotuner: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] self.cache = dict() - self.kernel = kernel # hook to reset all required tensor to zeros before relaunching a kernel self.hook = lambda args: 0 if reset_to_zero is not None: @@ -41,6 +41,7 @@ class Autotuner: perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune + self.fn = fn def _bench(self, *args, config, **meta): # check for conflicts, i.e. meta-parameters both provided @@ -58,25 +59,16 @@ class Autotuner: if config.pre_hook: config.pre_hook(self.nargs) self.hook(args) - self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) return do_bench(kernel_call) - def __call__(self, *args, **kwargs): + def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) if len(self.configs) > 1: key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -91,13 +83,41 @@ class Autotuner: self.best_config = config if config.pre_hook is not None: config.pre_hook(self.nargs) - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None class Config: """ An object that represents a possible kernel configuration for the auto-tuner to try. - :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. :type meta: dict[Str, Any] :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if @@ -129,10 +149,8 @@ class Config: def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): """ Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python .. code-block:: python - @triton.autotune(configs=[ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), @@ -143,12 +161,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. This means that whatever value the kernel updates will be updated multiple times. To avoid this undesired behavior, you can use the `reset_to_zero` argument, which reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects :type configs: list[triton.Config] :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. @@ -161,43 +177,39 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): :type reset_to_zero: list[str] """ def decorator(fn): - def wrapper(kernel): - return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) - - fn.kernel_decorators.append(wrapper) - return fn + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) return decorator +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + def heuristics(values): """ Decorator for specifying how the values of certain meta-parameters may be computed. This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. - .. highlight:: python .. code-block:: python - @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - - .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. each such function takes a list of positional arguments as input. .type values: dict[str, Callable[[list[Any]], Any]] """ def decorator(fn): - def wrapper(kernel): - def fun(*args, **meta): - for v, heur in values.items(): - assert v not in meta - meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) - return kernel(*args, **meta) - return fun - - fn.kernel_decorators.append(wrapper) - return fn + return Heuristics(fn, fn.arg_names, values) return decorator diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 7fe270c97..4539d3c08 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from __future__ import annotations, division import ast import functools @@ -6,90 +6,24 @@ import hashlib import inspect import os import subprocess -import tempfile import textwrap -from typing import Any, Dict, List, Optional +from collections import namedtuple import torch import triton -import triton._C.libtriton.triton as _triton -from ..compiler import compile -from ..tools.disasm import extract +from triton.utils import MockTensor try: from torch._C import _cuda_getCurrentRawStream as get_cuda_stream except ImportError: get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream -# ----------------------------------------------------------------------------- -# Binary -# ----------------------------------------------------------------------------- - -VALID_BACKENDS: List[str] = ( - _triton.runtime.backend.CUDA, -) - - -class Binary: - def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int): - assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend) - self.backend = backend - self.name = name - self.asm = asm - self.shared_mem = shared_mem - self.num_warps = num_warps - - -class LoadedBinary: - def __init__(self, device: int, bin: Binary): - module, kernel = _triton.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, - device) - self.bin = bin - self.asm = bin.asm - self.sass = '' - self.module = module - self.kernel = kernel - self.device = device - self.shared_mem = bin.shared_mem - - def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): - _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, - args, self.bin.shared_mem) - - def get_sass(self, fun=None): - if self.sass: - return self.sass - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass - -# ----------------------------------------------------------------------------- -# Kernel -# ----------------------------------------------------------------------------- - - -class Kernel: - - def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs): - raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.") - - # ----------------------------------------------------------------------------- # Dependencies Finder # ----------------------------------------------------------------------------- + class DependenciesFinder(ast.NodeVisitor): """ This AST visitor is used to find dependencies of a JITFunction. This can @@ -142,6 +76,8 @@ def version_key(): # frontend with open(__file__, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] + with open(triton.compiler.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] # backend with open(triton._C.libtriton.__file__, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] @@ -158,26 +94,213 @@ def version_key(): return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) -class JITFunction: +class KernelInterface: + + def __getitem__(self, grid): + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + def launcher(*args, **kwargs): + return self.run(*args, grid=grid, **kwargs) + return launcher + + +class JITFunction(KernelInterface): cache_hook = None + divisibility = 16 - def __init__(self, fn, version=None, inline=True, do_not_specialize=None): - # information of wrapped function + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -2**31 <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**31 <= arg and arg <= 2**32 - 1: + return "u32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return 'fp32' + elif arg is None: + return None + else: + raise TypeError(f'Unsupported type {type(arg)} for {arg}') + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return (arg.data_ptr() % JITFunction.divisibility == 0) + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) + + @staticmethod + def _type_of(key): + if isinstance(key, (torch.dtype, triton.language.dtype)): + ty = { + torch.bool: 'i1', + torch.float16: 'fp16', + torch.bfloat16: 'bf16', + torch.float32: 'fp32', + torch.float64: 'fp64', + torch.uint8: 'u8', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + + triton.language.uint8: 'u8', + triton.language.uint16: 'u16', + triton.language.uint32: 'u32', + triton.language.uint64: 'u64', + triton.language.float8: 'fp8', + }[key] + return f'*{ty}' + if key is None: + return '*i8' + assert isinstance(key, str) + return key + + def _make_signature(self, sig_key): + signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) + return signature + + def _make_constants(self, constexpr_key): + constants = {i: k for i, k in zip(self.constexprs, constexpr_key)} + return constants + + def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + if JITFunction.cache_hook is None: + return False + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + key = str(key) + + class LegacyCompiler: + def __init__(self, module, name): + self.module = module + self.name = name + pass + + kwargs = dict(signature=signature, device=device, constants=constants, + num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, + configs=configs) + + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + + def _make_launcher(self): + regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] + constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] + args = ', '.join(regular_args) + # cache key for regular argument type + sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args]) + # cache key for constexpr argument values + constexpr_keys = ', '.join(constexpr_args) + # cache key for argument specialization + specializations = [] + for i, arg in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") ' + f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) ' + f'else (False,)'] + spec_keys = ', '.join(specializations) + grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + + src = f""" +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False): + sig_key = {sig_keys}, + constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} + spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} + key = (version_key, sig_key, constexpr_key, spec_key) + if not extern_libs is None: + key = (key, tuple(extern_libs.items())) + assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + if callable(grid): + grid = grid({{{grid_args}}}) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + device = torch.cuda.current_device() + torch.cuda.set_device(device) + if stream is None and not warmup: + stream = get_cuda_stream(device) + try: + bin = cache[key] + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args}) + return bin + # kernel not cached -- compile + except KeyError: + # build dict of constant values + args = [{args}] + configs = self._get_config(*args), + constants = self._make_constants(constexpr_key) + constants.update({{i: None for i, arg in enumerate(args) if arg is None}}) + constants.update({{i: 1 for i in configs[0].equal_to_1}}) + # build kernel signature -- doesn't include specialized arguments + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + device = 0 + if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args) + self.cache[key] = bin + return bin + return None +""" + scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream, + "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of, + "cache": self.cache, "triton": triton, "torch": torch} + exec(src, scope) + return scope[self.fn.__name__] + + def __init__(self, fn, version=None, do_not_specialize=None): self.fn = fn self.module = fn.__module__ + self.version = version + # function signature information signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] - self.arg_defaults = [v.default for v in signature.parameters.values()] - - self.version = version - self.inline = inline + self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()]) + # specialization hints + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]) + # function source code (without decorators) self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] - # cache for callable driver objects (e.g. CUkernel) - self.bin_cache = dict() + # cache of just-in-time compiled kernels + self.cache = dict() self.hash = None # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ @@ -186,16 +309,17 @@ class JITFunction: # annotations self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.__annotations__ = fn.__annotations__ - # constexprs + # index of constexprs self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] - # forward docs + # launcher + self.run = self._make_launcher() + # re-use docs of wrapped function self.__doc__ = fn.__doc__ self.__name__ = fn.__name__ self.__globals__ = fn.__globals__ self.__module__ = fn.__module__ @property - @functools.lru_cache() def cache_key(self): # TODO : hash should be attribute of `self` if self.hash is None: @@ -204,9 +328,12 @@ class JITFunction: self.hash = dependencies_finder.ret + version_key() return self.hash + def warmup(self, *args, **kwargs): + return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. - # Some unit tests do this, for example. + # Our unit tests do this, for example. def parse(self): tree = ast.parse(self.src) assert isinstance(tree, ast.Module) @@ -217,167 +344,21 @@ class JITFunction: def __call__(self, *args, **kwargs): raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") - # - when `.src` attribute is set, cache path needs - # to be reinitialized - # - when kernel decorators change, cached kernel - # needs to be cleared def __setattr__(self, name, value): + # - when kernel decorators change, cached kernel + # needs to be cleared if name == 'kernel_decorators': self.kernel = None super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized if name == 'src': self.hash = None - JITFunction.cache_key.fget.cache_clear() - - def _init_kernel(self): - if self.kernel is None: - self.kernel = Kernel(self) - for decorator in reversed(self.kernel_decorators): - self.kernel = decorator(self.kernel) - return self.kernel - - def __getitem__(self, grid): - """ - A JIT function is launched with: fn[grid](*args, **kwargs). - Hence JITFunction.__getitem__ returns a callable proxy that - memorizes the grid. - """ - class Launcher: - def __init__(self, kernel, grid): - self.kernel = kernel - self.grid = grid - - def __call__(self, *wargs, **kwargs): - return self.kernel(*wargs, **kwargs, grid=self.grid) - - return Launcher(self._init_kernel(), grid) def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" -def pow2_divisor(N): - if N % 16 == 0: - return 16 - if N % 8 == 0: - return 8 - if N % 4 == 0: - return 4 - if N % 2 == 0: - return 2 - return 1 - - -class _KernelCache: - def __init__(self, - fn: JITFunction, - fn_type: str, - constants: Dict[str, Any], - num_warps: int = 4, - num_stages: int = 3): - # hold the arguments for building a kernel - self.fn = fn - self.fn_type = fn_type - self.constants = constants - self.num_warps = num_warps - self.num_stages = num_stages - - # kernel compilation cache - self._binary_cache: Optional[LoadedBinary] = None - - @property - def binary_cache(self): - return self._binary_cache - - def set_binary_cache(self, binary: LoadedBinary): - assert binary - assert not self._binary_cache, "cannot set binary cache duplicately" - self._binary_cache = binary - - -def build_kernel(fn: JITFunction, - fn_type: str, - constants: Dict[str, Any], - num_warps: int = 4, - num_stages: int = 3, - ) -> _KernelCache: - return _KernelCache(fn, fn_type, constants, num_warps, num_stages) - - -torch_dtype_to_bytes = { - torch.int8: 1, - torch.uint8: 1, - - torch.int16: 2, - torch.short: 2, - - torch.int: 4, - torch.int32: 4, - - torch.long: 8, - torch.int64: 8, - - torch.float32: 4, - torch.float: 4, - - torch.float16: 2, - torch.half: 2, - torch.bfloat16: 2, - # free to extend -} - - -def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs): - def is_tensor(arg): - return hasattr(arg, 'data_ptr') # a torch.tensor - - # prepare function args for compile - kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()} - wargs = list(wargs) - for i, pos in enumerate(sorted(kwargs)): - wargs.insert(pos + i, kwargs[pos]) - assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs)) - - if not kernel.binary_cache: - # build the kernel cache - backend = _triton.runtime.backend.CUDA - - attributes = dict() - for i, arg in enumerate(wargs): - if i in kernel.fn.do_not_specialize: - continue - if isinstance(arg, int): - attributes[i] = pow2_divisor(arg) - elif is_tensor(arg): - assert arg.dtype in torch_dtype_to_bytes - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype] - attributes[i] = divisibility - - attributes_ = dict() - for i, value in attributes.items(): - attributes_[kernel.fn.arg_names[i]] = value - - cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin") - assert cubin - assert kernel_name - - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size) - - asm = dict(cubin=cubin) - binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps) - loaded_binary = LoadedBinary(device, binary) - kernel.set_binary_cache(loaded_binary) - - torch.cuda.set_device(device) - stream = get_cuda_stream(device) - - _triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names, - stream, kernel.num_warps, kernel.num_stages, grid) - - # ----------------------------------------------------------------------------- # `jit` decorator # ----------------------------------------------------------------------------- @@ -386,16 +367,12 @@ def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs): def jit(*args, **kwargs): """ Decorator for JIT-compiling a function using the Triton compiler. - :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. - :note: This function will be compiled and run on the GPU. It will only have access to: - * python primitives, * objects within the triton.language package, * arguments to this function, * other jit'd functions - :param fn: the function to be jit-compiled :type fn: Callable """ @@ -407,3 +384,32 @@ def jit(*args, **kwargs): def decorator(fn): return JITFunction(fn, **kwargs) return decorator + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/utils.py b/python/triton/utils.py index f446dd06a..2ac84d06e 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,6 +19,24 @@ def next_power_of_2(n): return n +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + @staticmethod + def wrap_dtype(arg): + if isinstance(arg, torch.dtype): + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + def data_ptr(self): + return 0 # optimistically assumes multiple of 16 + + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype