[FRONTEND] Backport new runtime from master (#706)

This PR merges the new runtime back into the `triton-mlir` branch. It adds caching and
just-in-time compilation functionality to the triton-mlir project, and paves the way for
reusing tests from the master branch.
Author: Philippe Tillet
Date: 2022-09-23 16:09:43 -07:00
Committed by: GitHub
Parent: ecd1bc33df
Commit: 22ec22c257
13 changed files with 790 additions and 419 deletions
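For context, a minimal sketch of the launch path this backport enables (the kernel, sizes, and block shape below are illustrative and assume a CUDA device); the first call triggers just-in-time compilation and populates the on-disk cache, later calls with the same signature and specialization reuse the cached binary:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(z_ptr + offs, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
z = torch.empty_like(x)
# first call compiles and caches; subsequent calls hit the cache
add_kernel[(triton.cdiv(1024, 256),)](x, y, z, 1024, BLOCK=256)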

View File

@@ -1,3 +1,5 @@
import torch
import triton
import triton.language as tl
@@ -7,4 +9,5 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)

View File

@@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) {
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
// CUDA
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
asm_map_t &asm_map,
size_t n_shared_bytes,
uint64_t dev) {
// load assembly
std::string assembly;
if (asm_map.find("cubin") != asm_map.end())
assembly = py::cast<std::string>(asm_map["cubin"]);
else
assembly = py::cast<std::string>(asm_map["ptx"]);
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
dev);
if (n_shared_bytes > 49152 && shared_optin > 49152) {
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
int n_spills, n_reg;
drv::dispatch::cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
dev);
drv::dispatch::cuFuncGetAttribute(&shared_static,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills,
CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncSetAttribute(
fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) {
m.def(
"load_binary",
[](backend_t backend, const std::string &name, asm_map_t &asm_map,
size_t n_shared_bytes, uint64_t dev) {
[](const std::string &name, const std::string &data,
size_t n_shared_bytes, uint64_t device) {
py::gil_scoped_release allow_threads;
assert(backend == CUDA); // Only CUDA is supported now.
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, data.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
fun);
drv::dispatch::cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
if (n_shared_bytes > 49152 && shared_optin > 49152) {
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
drv::dispatch::cuDeviceGetAttribute(
&shared_total,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
drv::dispatch::cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncSetAttribute(
fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
(uint64_t)n_spills);
},
py::return_value_policy::take_ownership);
}
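For reference, a hedged sketch of how the reworked binding is consumed from the Python side (it mirrors CompiledKernel later in this diff; the cubin bytes and shared-memory size are assumed to come from a prior compilation):

import torch
import triton._C.libtriton.triton as _triton

def load_cubin(kernel_name, cubin_bytes, shared_bytes):
    device = torch.cuda.current_device()
    # new return tuple: CUmodule handle, CUfunction handle,
    # allocated registers, and spilled registers
    mod, func, n_regs, n_spills = _triton.load_binary(
        kernel_name, cubin_bytes, shared_bytes, device)
    return mod, func, n_regs, n_spills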

View File

@@ -2,8 +2,9 @@ import triton
import triton.language as tl
# TODO: functions with no arguments don't work
@triton.jit
def cast_check():
def cast_check(X):
zero_0d = tl.zeros([], dtype=tl.float32)
zero_1d = tl.zeros([2], dtype=tl.float32)
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
@@ -48,9 +49,9 @@ def cast_check():
def test_cast_check():
kernel = triton.compile(cast_check,
signature="",
device=0,
output="ttir")
kernel = triton.compiler._compile(cast_check,
signature="*fp32",
device=0,
output="ttgir")
assert (kernel)
# TODO: Check types of the results

View File

@@ -2,7 +2,6 @@ import torch
import triton
import triton.language as tl
import triton.runtime as runtime
# trigger the torch.device implicitly to ensure cuda context initialization
torch.zeros([10], device=torch.device('cuda'))
@@ -16,30 +15,18 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
def test_empty_kernel_cubin_compile():
device = torch.cuda.current_device()
cubin = triton.compile(empty_kernel,
"*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
output="cubin")
kernel = triton.compile(empty_kernel,
"*fp32,i32,i32",
device=device,
constants={"BLOCK": 256})
print('cubin size:', len(cubin))
assert len(cubin) > 0
assert len(kernel.asm["cubin"]) > 0
def test_empty_kernel_launch():
device = torch.cuda.current_device()
binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
constants={"BLOCK": 256},
num_warps=4,
num_stages=3)
grid = lambda META: (
triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
)
A = torch.zeros([1024], device="cuda")
runtime.launch_kernel(kernel=binary,
grid=grid,
device=device,
X=A,
stride_xm=256,
BLOCK=tl.constexpr(256))
empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)

View File

@@ -23,11 +23,11 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
def test_empty_kernel_cubin_compile():
kernel = triton.compile(math_kernel,
"*fp32,*fp32,*fp32,*fp32,i32",
device=0,
constants={"BLOCK_SIZE": 256},
output="ttgir") # "cubin"
kernel = triton.compiler._compile(math_kernel,
"*fp32,*fp32,*fp32,*fp32,i32",
device=0,
constants={"BLOCK_SIZE": 256},
output="ttgir") # "cubin"
assert kernel
# TODO: Check if the values are correct.
# TODO: Cover all the math operators

View File

@@ -4,7 +4,6 @@ from torch.testing import assert_allclose
import triton
import triton.language as tl
import triton.runtime as runtime
@triton.jit
@@ -40,29 +39,9 @@ def kernel(x_ptr, stride_xm,
[2, 128, 64]
])
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
# TODO: this is to initialize the cuda context since it is not properly
# dealed with in the existing runtime, remove this when the runtime
# is updated
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel,
"*fp32,i32,*fp32,i32",
constants={"SIZE_M": SIZE_M,
"SIZE_N": SIZE_N},
num_warps=NUM_WARPS,
num_stages=3)
grid = lambda META: (1, )
x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
runtime.launch_kernel(kernel=binary,
device=device,
grid=grid,
x_ptr=x,
stride_xm=x.stride(0),
z_ptr=z,
stride_zn=z.stride(0),
SIZE_M=tl.constexpr(SIZE_M),
SIZE_N=tl.constexpr(SIZE_N))
kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
golden_z = torch.t(x)
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)

View File

@@ -3,7 +3,6 @@ from torch.testing import assert_allclose
import triton
import triton.language as tl
import triton.runtime as runtime
def vecadd_no_scf_tester(num_warps, block_size):
@@ -22,27 +21,13 @@ def vecadd_no_scf_tester(num_warps, block_size):
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
constants={"BLOCK_SIZE_N": block_size},
num_warps=num_warps,
num_stages=3)
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
assert x.shape.numel() % block_size == 0, "Only test load without mask here"
grid = lambda EA: (x.shape.numel() // block_size,)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
runtime.launch_kernel(kernel=binary,
grid=grid,
device=device,
x_ptr=x,
y_ptr=y,
z_ptr=z,
BLOCK_SIZE_N=tl.constexpr(block_size))
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)

View File

@@ -7,8 +7,9 @@ __version__ = '2.0.0'
import torch
# submodules
from .utils import *
from .runtime import jit, Config, autotune, heuristics
from .compiler import compile
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import testing
from . import ops
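A hedged sketch of the reshuffled top-level exports (the kernel is illustrative):

import triton
import triton.language as tl

@triton.jit                                   # re-exported from triton.runtime.jit
def _noop(X, BLOCK: tl.constexpr):
    pass

assert isinstance(_noop, triton.JITFunction)            # JITFunction is now public
assert issubclass(triton.JITFunction, triton.KernelInterface)
# triton.compile and triton.CompilationError are exported from triton.compiler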

View File

@@ -1,10 +1,26 @@
from __future__ import annotations
import ast
import contextlib
import functools
import hashlib
import io
import json
import os
import shutil
import subprocess
import sys
import sysconfig
import tempfile
import warnings
from collections import namedtuple
from sysconfig import get_paths
from typing import Any, Dict, Tuple, Union
import setuptools
import torch
from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
@@ -85,7 +101,7 @@ class enter_sub_region:
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
@@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.function_name = function_name
self.is_kernel = is_kernel
self.last_node = None
self.builtins = {
@@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor):
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# initialize function
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
@@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
continue
else:
pass
if i in self.attributes:
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
@@ -746,21 +762,40 @@ class OutOfResources(Exception):
return (type(self), (self.required, self.limit, self.name))
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
def kernel_suffix(signature, specialization):
# suffix format:
# <argid><'c' if equal to 1><'d' if divisible by 16>
suffix = ''
for i, _ in enumerate(signature):
suffix += str(i)
if i in specialization.equal_to_1:
suffix += 'c'
if i in specialization.divisible_by_16:
suffix += 'd'
return suffix
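# Hedged worked example of the suffix format above (specialization values are illustrative):
_example_spec = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])({0}, {2})
assert kernel_suffix(["*fp32", "i32", "i32"], _example_spec) == "0d12c"  # mangled name: "<fn>_0d12c"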
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
def make_triton_ir(fn, signature, specialization, constants):
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
if signature.replace(' ', '') != '':
arg_types = signature.replace(' ', '').split(',')
arg_types = [str_to_ty(x) for x in arg_types]
else:
arg_types = []
prototype = triton.language.function_type([], arg_types)
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
constants = {cst_key(key): value for key, value in constants.items()}
# visit kernel AST
gscope = fn.__globals__.copy()
generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
tys = list(signature.values())
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
all_constants = constants.copy()
all_constants.update(new_constants)
arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
prototype = triton.language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
try:
generator.visit(fn.parse())
except Exception as e:
@@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
raise e
raise CompilationError(fn.src, node) from e
ret = generator.module
# module takes ownership of the MLIR context
# module takes ownership of the context
ret.context = context
return ret
return ret, generator
def optimize_triton_ir(mod):
@@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
module = make_triton_ir(fn, signature, constants, attributes)
module, _ = make_triton_ir(fn, signature, specialization, constants)
module = optimize_triton_ir(module)
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
@@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d
cubin = make_cubin(ptx, device)
if output == "cubin":
return cubin, shem_size, kernel_name
return cubin, ptx, shem_size, kernel_name
assert False
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
def ty_to_cpp(ty):
if ty[0] == '*':
return "CUdeviceptr"
return {
"i1": "int32_t",
"i8": "int8_t",
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp32": "float",
}[ty]
def generate_name_initializer(signature):
src = "int i = 0;\n"
tys = signature.split(',')
for i, ty in enumerate(tys):
src
def binary_name_to_header_name(name):
if len(name) > 128:
# avoid filename too long errors (filename limit is 255)
name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest()
return f"{name}.h"
def generate_launcher(identifier, constants, signature):
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp32': 'float',
'fp64': 'double',
}[ty]
def format_of(ty):
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
"uint32_t": "I",
"int32_t": "i",
"uint64_t": "K",
"int64_t": "L",
}[ty]
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
src = f"""
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
{{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
if (PyLong_Check(obj)) {{
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
}}
if (obj == Py_None) {{
return (CUdeviceptr)0;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
}}
return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return (CUdeviceptr)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
return src
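# Hedged sketch of how the generated launcher source above is compiled and imported
# (it mirrors the compile()/CompiledKernel path further below; names are illustrative):
def _load_launcher(name, constants, signature):
    import importlib.util
    src = generate_launcher(name, constants, signature)      # C source string
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "main.c")
        with open(src_path, "w") as f:
            f.write(src)
        so_path = _build(name, src_path, tmpdir)              # cc, or setuptools fallback
        spec = importlib.util.spec_from_file_location("launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        # launch(gridX, gridY, gridZ, num_warps, shared_memory, stream, function, *args)
        return mod.launch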
def default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "cache")
class CacheManager:
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
def _make_path(self, filename):
return os.path.join(self.cache_dir, filename)
def has_file(self, filename):
if not self.cache_dir:
return False
return os.path.exists(self._make_path(filename))
def put(self, data, filename, binary=True):
if not self.cache_dir:
return
assert self.lock_path is not None
filepath = self._make_path(filename)
with FileLock(self.lock_path):
# use tempfile to be robust against program interruptions
mode = "wb" if binary else "w"
with open(filepath + ".tmp", mode) as f:
f.write(data)
os.rename(filepath + ".tmp", filepath)
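# Hedged usage sketch for CacheManager (key and file names are illustrative):
def _cache_example(cubin_bytes):
    cm = CacheManager("0123abcd")                 # key is normally an md5 digest
    if not cm.has_file("kernel.cubin"):
        cm.put(cubin_bytes, "kernel.cubin", binary=True)
    return cm._make_path("kernel.cubin")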
# utilities for generating and compiling C wrappers
@functools.lru_cache()
def libcuda_dir():
loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
return os.path.dirname(loc)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(name, src, srcdir):
cuda_lib_dir = libcuda_dir()
cu_include_dir = "/usr/local/cuda/include"
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
cc = os.environ.get("CC")
if cc is None:
# TODO: support more things here.
clang = shutil.which("clang")
gcc = shutil.which("gcc")
cc = gcc if gcc is not None else clang
py_include_dir = get_paths()["include"]
ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = [cuda_lib_dir]
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name=name,
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
)
# build extension module
args = ['build_ext']
args.append('--build-temp=' + srcdir)
args.append('--build-lib=' + srcdir)
args.append('-q')
args = dict(
name=name,
ext_modules=[ext],
script_args=args,
)
with quiet():
setuptools.setup(**args)
return so
def make_so_cache_key(signature, constants):
# Get unique key for the compiled code
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
key = f"{''.join(signature.values())}{constants}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
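# Hedged example of the two cache keys above (hash inputs are illustrative):
_sig_example = {0: "*fp32", 1: "i32"}
_so_key = make_so_cache_key(_sig_example, {"BLOCK": 256})        # md5 of "ptri32{'BLOCK': 256}"
_fn_key = make_fn_cache_key("deadbeef", _sig_example, [instance_descriptor()], {"BLOCK": 256}, 4, 3)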
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
# we get the kernel, i.e. the first function generated in the module
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
# cache manager
name = fn.__name__
# name of files that are cached
so_cache_key = make_so_cache_key(signature, constants)
so_cache_manager = CacheManager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
if not so_cache_manager.has_file(so_name):
with tempfile.TemporaryDirectory() as tmpdir:
src = generate_launcher(name, constants, signature)
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(fn.__name__, src_path, tmpdir)
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
# retrieve cached shared object if it exists
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
data_name = f"{name}.json"
if not fn_cache_manager.has_file(cubin_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name):
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(cubin, cubin_name)
fn_cache_manager.put(ptx, ptx_name, binary=False)
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
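# Hedged sketch of what compile() leaves behind on disk (paths assume the default
# cache directory; the key components are illustrative):
#   ~/.triton/cache/<so_key>/<name>.so      # compiled C launcher stub
#   ~/.triton/cache/<fn_key>/<name>.ptx     # PTX
#   ~/.triton/cache/<fn_key>/<name>.cubin   # CUBIN later loaded by CompiledKernel
#   ~/.triton/cache/<fn_key>/<name>.json    # metadata: kernel name, shared mem, warps, stages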
class CompiledKernel:
def __init__(self, fn_name, so_path, cache_dir):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
metadata = json.load(f)
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = dict()
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
device = torch.cuda.current_device()
mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
def __getitem__(self, grid):
def runner(*args, stream=None):
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
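A hedged end-to-end sketch of the ahead-of-time path above, mirroring the empty-kernel unit test in this PR (the kernel, signature, and constants are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def _empty(X, stride_xm, BLOCK: tl.constexpr):
    pass

device = torch.cuda.current_device()
kernel = triton.compile(_empty, "*fp32,i32,i32", device=device, constants={"BLOCK": 256})
assert len(kernel.asm["cubin"]) > 0 and len(kernel.asm["ptx"]) > 0
# the handle can then be launched directly: kernel[(grid_x, grid_y, grid_z)](*args)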

View File

@@ -1,2 +1,2 @@
from .autotuner import Config, autotune, heuristics # noqa: F401
from .jit import JITFunction, build_kernel, jit, launch_kernel # noqa: F401
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
from .jit import JITFunction, KernelInterface, version_key # noqa: F401

View File

@@ -5,10 +5,11 @@ import time
from typing import Dict
from ..testing import do_bench
from .jit import KernelInterface
class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs; returns running time
@@ -21,7 +22,6 @@ class Autotuner:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
self.kernel = kernel
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
@@ -41,6 +41,7 @@ class Autotuner:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -58,25 +59,16 @@ class Autotuner:
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return do_bench(kernel_call)
def __call__(self, *args, **kwargs):
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
@@ -91,13 +83,41 @@ class Autotuner:
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
@@ -129,10 +149,8 @@ class Config:
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
@@ -143,12 +161,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
@@ -161,43 +177,39 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:type reset_to_zero: list[str]
"""
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
fn.kernel_decorators.append(wrapper)
return fn
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
return decorator
class Heuristics(KernelInterface):
def __init__(self, fn, arg_names, values) -> None:
self.fn = fn
self.values = values
self.arg_names = arg_names
def run(self, *args, **kwargs):
for v, heur in self.values.items():
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
return self.fn.run(*args, **kwargs)
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
Each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
"""
def decorator(fn):
def wrapper(kernel):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta)
return fun
fn.kernel_decorators.append(wrapper)
return fn
return Heuristics(fn, fn.arg_names, values)
return decorator

View File

@@ -1,4 +1,4 @@
from __future__ import annotations
from __future__ import annotations, division
import ast
import functools
@@ -6,90 +6,24 @@ import hashlib
import inspect
import os
import subprocess
import tempfile
import textwrap
from typing import Any, Dict, List, Optional
from collections import namedtuple
import torch
import triton
import triton._C.libtriton.triton as _triton
from ..compiler import compile
from ..tools.disasm import extract
from triton.utils import MockTensor
try:
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
# -----------------------------------------------------------------------------
# Binary
# -----------------------------------------------------------------------------
VALID_BACKENDS: List[str] = (
_triton.runtime.backend.CUDA,
)
class Binary:
def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
self.backend = backend
self.name = name
self.asm = asm
self.shared_mem = shared_mem
self.num_warps = num_warps
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
self.bin = bin
self.asm = bin.asm
self.sass = ''
self.module = module
self.kernel = kernel
self.device = device
self.shared_mem = bin.shared_mem
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem)
def get_sass(self, fun=None):
if self.sass:
return self.sass
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass
# -----------------------------------------------------------------------------
# Kernel
# -----------------------------------------------------------------------------
class Kernel:
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
@@ -142,6 +76,8 @@ def version_key():
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
with open(triton.compiler.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
@@ -158,26 +94,213 @@ def version_key():
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class JITFunction:
class KernelInterface:
def __getitem__(self, grid):
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
def launcher(*args, **kwargs):
return self.run(*args, grid=grid, **kwargs)
return launcher
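# Hedged usage note: KernelInterface supplies the bracket-launch sugar, so for any
# subclass (JITFunction, Autotuner, Heuristics) fn[grid](*args, **kwargs) simply
# forwards to fn.run(*args, grid=grid, **kwargs), e.g. (names illustrative):
#   add_kernel[(triton.cdiv(n, 128),)](x, y, z, n, BLOCK=128)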
class JITFunction(KernelInterface):
cache_hook = None
divisibility = 16
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function
@staticmethod
def _key_of(arg):
if hasattr(arg, "dtype"):
return arg.dtype
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -2**31 <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**31 <= arg and arg <= 2**32 - 1:
return "u32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return 'fp32'
elif arg is None:
return None
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
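# Hedged examples of _key_of (a tensor contributes its dtype; Python scalars are
# bucketed by value range; bool is checked before int):
#   _key_of(torch.randn(4, device="cuda"))  -> torch.float32
#   _key_of(True)                           -> "i1"
#   _key_of(7)                              -> "i32"
#   _key_of(2**40)                          -> "i64"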
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
return (arg.data_ptr() % JITFunction.divisibility == 0)
elif isinstance(arg, int):
return (arg % 16 == 0, arg == 1)
return (arg is None, )
def _get_config(self, *args):
def is_divisible_by_16(x):
if hasattr(x, "data_ptr"):
return x.data_ptr() % JITFunction.divisibility == 0
elif isinstance(x, int):
return x % JITFunction.divisibility == 0
if x is None:
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
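# Hedged example of the resulting descriptor (assuming do_not_specialize is empty and
# x.data_ptr() is 16-byte aligned):
#   _get_config(x, 4096, 1)  -> divisible_by_16=(0, 1), equal_to_1=(2,)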
@staticmethod
def _type_of(key):
if isinstance(key, (torch.dtype, triton.language.dtype)):
ty = {
torch.bool: 'i1',
torch.float16: 'fp16',
torch.bfloat16: 'bf16',
torch.float32: 'fp32',
torch.float64: 'fp64',
torch.uint8: 'u8',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
}[key]
return f'*{ty}'
if key is None:
return '*i8'
assert isinstance(key, str)
return key
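# Hedged examples of _type_of: dtypes map to pointer type strings, None maps to "*i8",
# and plain strings (already type names) pass through:
#   _type_of(torch.float16)           -> "*fp16"
#   _type_of(triton.language.float8)  -> "*fp8"
#   _type_of(None)                    -> "*i8"
#   _type_of("i32")                   -> "i32"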
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
def __init__(self, module, name):
self.module = module
self.name = name
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
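# Hedged sketch of installing a cache hook (the hook body is illustrative); the hook is
# consulted right before compilation, and a falsy return value lets compilation proceed:
#   def _log_compiles(key, repr, fn, compile, is_manual_warmup, already_compiled):
#       print("triton: compiling", repr)
#       return False
#   JITFunction.cache_hook = _log_compiles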
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
specializations = []
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
f'else (False,)']
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
key = (version_key, sig_key, constexpr_key, spec_key)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
if callable(grid):
grid = grid({{{grid_args}}})
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
device = torch.cuda.current_device()
torch.cuda.set_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
try:
bin = cache[key]
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
return bin
# kernel not cached -- compile
except KeyError:
# build dict of constant values
args = [{args}]
configs = self._get_config(*args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
device = 0
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
self.cache[key] = bin
return bin
return None
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton, "torch": torch}
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None):
self.fn = fn
self.module = fn.__module__
self.version = version
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.inline = inline
self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
# cache for callable driver objects (e.g. CUkernel)
self.bin_cache = dict()
# cache of just-in-time compiled kernels
self.cache = dict()
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
@@ -186,16 +309,17 @@ class JITFunction:
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# constexprs
# index of constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
@functools.lru_cache()
def cache_key(self):
# TODO : hash should be attribute of `self`
if self.hash is None:
@@ -204,9 +328,12 @@ class JITFunction:
self.hash = dependencies_finder.ret + version_key()
return self.hash
def warmup(self, *args, **kwargs):
return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Some unit tests do this, for example.
# Our unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
@@ -217,167 +344,21 @@ class JITFunction:
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
# - when `.src` attribute is set, cache path needs
# to be reinitialized
# - when kernel decorators change, cached kernel
# needs to be cleared
def __setattr__(self, name, value):
# - when kernel decorators change, cached kernel
# needs to be cleared
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
# - when `.src` attribute is set, cache path needs
# to be reinitialized
if name == 'src':
self.hash = None
JITFunction.cache_key.fget.cache_clear()
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
for decorator in reversed(self.kernel_decorators):
self.kernel = decorator(self.kernel)
return self.kernel
def __getitem__(self, grid):
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
class Launcher:
def __init__(self, kernel, grid):
self.kernel = kernel
self.grid = grid
def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid)
return Launcher(self._init_kernel(), grid)
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
def pow2_divisor(N):
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
class _KernelCache:
def __init__(self,
fn: JITFunction,
fn_type: str,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3):
# hold the arguments for building a kernel
self.fn = fn
self.fn_type = fn_type
self.constants = constants
self.num_warps = num_warps
self.num_stages = num_stages
# kernel compilation cache
self._binary_cache: Optional[LoadedBinary] = None
@property
def binary_cache(self):
return self._binary_cache
def set_binary_cache(self, binary: LoadedBinary):
assert binary
assert not self._binary_cache, "cannot set binary cache duplicately"
self._binary_cache = binary
def build_kernel(fn: JITFunction,
fn_type: str,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3,
) -> _KernelCache:
return _KernelCache(fn, fn_type, constants, num_warps, num_stages)
torch_dtype_to_bytes = {
torch.int8: 1,
torch.uint8: 1,
torch.int16: 2,
torch.short: 2,
torch.int: 4,
torch.int32: 4,
torch.long: 8,
torch.int64: 8,
torch.float32: 4,
torch.float: 4,
torch.float16: 2,
torch.half: 2,
torch.bfloat16: 2,
# free to extend
}
def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
def is_tensor(arg):
return hasattr(arg, 'data_ptr') # a torch.tensor
# prepare function args for compile
kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs))
if not kernel.binary_cache:
# build the kernel cache
backend = _triton.runtime.backend.CUDA
attributes = dict()
for i, arg in enumerate(wargs):
if i in kernel.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = pow2_divisor(arg)
elif is_tensor(arg):
assert arg.dtype in torch_dtype_to_bytes
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
attributes[i] = divisibility
attributes_ = dict()
for i, value in attributes.items():
attributes_[kernel.fn.arg_names[i]] = value
cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
assert cubin
assert kernel_name
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
asm = dict(cubin=cubin)
binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
loaded_binary = LoadedBinary(device, binary)
kernel.set_binary_cache(loaded_binary)
torch.cuda.set_device(device)
stream = get_cuda_stream(device)
_triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
stream, kernel.num_warps, kernel.num_stages, grid)
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
@@ -386,16 +367,12 @@ def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
def jit(*args, **kwargs):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
@@ -407,3 +384,32 @@ def jit(*args, **kwargs):
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
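Finally, a hedged usage sketch for TensorWrapper/reinterpret (the dtype choice is illustrative): reinterpretation wraps the tensor so kernels see a different element type without copying, and reinterpreting back to the base dtype unwraps it.

import torch
import triton.language as tl
from triton.runtime.jit import reinterpret   # defined in this file

x = torch.empty(16, dtype=torch.int8, device="cuda")
x_fp8 = reinterpret(x, tl.float8)            # wraps x; data_ptr() forwards to the base tensor
assert reinterpret(x_fp8, torch.int8) is x   # reinterpreting back to the base dtype unwraps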

View File

@@ -19,6 +19,24 @@ def next_power_of_2(n):
return n
class MockTensor:
"""
Can be used in place of real tensors when calling:
kernel.warmup(MockTensor(torch.float32), ...)
"""
@staticmethod
def wrap_dtype(arg):
if isinstance(arg, torch.dtype):
return MockTensor(arg)
return arg
def __init__(self, dtype):
self.dtype = dtype
def data_ptr(self):
return 0 # optimistically assumes multiple of 16
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype