[FRONTEND] Backport new runtime from master (#706)

This PR merges the new runtime back into the `triton-mlir` branch. It adds caching and
just-in-time compilation functionality to the triton-mlir project and paves the way for
reusing tests from the master branch.
Author: Philippe Tillet (committed by GitHub)
Date:   2022-09-23 16:09:43 -07:00
Parent: ecd1bc33df
Commit: 22ec22c257
13 changed files with 790 additions and 419 deletions
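Before the diff, a rough usage sketch of the backported runtime, end to end. Everything here is illustrative: `add_kernel`, the grid, and the `triton.compiler` import path are assumptions; the argument handling follows the new `compile()` and `CompiledKernel` code shown further down.

import torch
import triton
import triton.language as tl
from triton.compiler import compile as triton_compile  # assumed import path for the module in this diff

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# compile() lowers the kernel to PTX/cubin, builds a C launcher stub, and caches the
# artifacts under ~/.triton/cache (or $TRITON_CACHE_DIR); a later call with the same
# signature/constants is served from the cache instead of recompiling.
binary = triton_compile(add_kernel,
                        signature="*fp32, *fp32, *fp32, i32",
                        device=torch.cuda.current_device(),
                        constants={"BLOCK": 1024})

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
# a CompiledKernel is launched by indexing it with a (gridX, gridY, gridZ) tuple
binary[(4096 // 1024, 1, 1)](x, y, out, x.numel())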


@@ -1,10 +1,26 @@
from __future__ import annotations
import ast
import contextlib
import functools
import hashlib
import io
import json
import os
import shutil
import subprocess
import sys
import sysconfig
import tempfile
import warnings
from collections import namedtuple
from sysconfig import get_paths
from typing import Any, Dict, Tuple, Union
import setuptools
import torch
from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
@@ -85,7 +101,7 @@ class enter_sub_region:
class CodeGenerator(ast.NodeVisitor):
- def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
+ def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
@@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.function_name = function_name
self.is_kernel = is_kernel
self.last_node = None
self.builtins = {
@@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor):
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# initialize function
- fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
- fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
+ fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
@@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
continue
else:
pass
if i in self.attributes:
- fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
+ fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
@@ -746,21 +762,40 @@ class OutOfResources(Exception):
return (type(self), (self.required, self.limit, self.name))
- def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
+ def kernel_suffix(signature, specialization):
+ # suffix format:
+ # <argid><'c' if equal to 1><'d' if divisible by 16>
+ suffix = ''
+ for i, _ in enumerate(signature):
+ suffix += str(i)
+ if i in specialization.equal_to_1:
+ suffix += 'c'
+ if i in specialization.divisible_by_16:
+ suffix += 'd'
+ return suffix
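# Illustrative example (assumed values): for signature values ["*fp32", "*fp32", "i32"]
# and specialization = instance_descriptor(divisible_by_16={0, 1}, equal_to_1={2}),
# kernel_suffix returns "0d1d2c", so a kernel named `add` becomes the IR function
# `add_0d1d2c` in make_triton_ir below.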
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
+ def make_triton_ir(fn, signature, specialization, constants):
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
- constants = {fn.arg_names.index(name): value for name, value in constants.items()}
- attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
- if signature.replace(' ', '') != '':
- arg_types = signature.replace(' ', '').split(',')
- arg_types = [str_to_ty(x) for x in arg_types]
- else:
- arg_types = []
- prototype = triton.language.function_type([], arg_types)
+ cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
+ constants = {cst_key(key): value for key, value in constants.items()}
# visit kernel AST
gscope = fn.__globals__.copy()
- generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
+ function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
+ tys = list(signature.values())
+ new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
+ new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
+ all_constants = constants.copy()
+ all_constants.update(new_constants)
+ arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
+ prototype = triton.language.function_type([], arg_types)
+ generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
try:
generator.visit(fn.parse())
except Exception as e:
@@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
raise e
raise CompilationError(fn.src, node) from e
ret = generator.module
- # module takes ownership of the MLIR context
+ # module takes ownership of the context
ret.context = context
- return ret
+ return ret, generator
def optimize_triton_ir(mod):
@@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
- def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
+ instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
+ def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
- module = make_triton_ir(fn, signature, constants, attributes)
+ module, _ = make_triton_ir(fn, signature, specialization, constants)
module = optimize_triton_ir(module)
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
@@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d
cubin = make_cubin(ptx, device)
if output == "cubin":
- return cubin, shem_size, kernel_name
+ return cubin, ptx, shem_size, kernel_name
assert False
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
def ty_to_cpp(ty):
if ty[0] == '*':
return "CUdeviceptr"
return {
"i1": "int32_t",
"i8": "int8_t",
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp32": "float",
}[ty]
def generate_name_initializer(signature):
src = "int i = 0;\n"
tys = signature.split(',')
for i, ty in enumerate(tys):
src
def binary_name_to_header_name(name):
if len(name) > 128:
# avoid filename too long errors (filename limit is 255)
name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest()
return f"{name}.h"
def generate_launcher(identifier, constants, signature):
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp32': 'float',
'fp64': 'double',
}[ty]
def format_of(ty):
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
"uint32_t": "I",
"int32_t": "i",
"uint64_t": "K",
"int64_t": "L",
}[ty]
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
src = f"""
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
{{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
if (PyLong_Check(obj)) {{
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
}}
if (obj == Py_None) {{
return (CUdeviceptr)0;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
}}
return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return (CUdeviceptr)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
return src
def default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "cache")
class CacheManager:
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
def _make_path(self, filename):
return os.path.join(self.cache_dir, filename)
def has_file(self, filename):
if not self.cache_dir:
return False
return os.path.exists(self._make_path(filename))
def put(self, data, filename, binary=True):
if not self.cache_dir:
return
assert self.lock_path is not None
filepath = self._make_path(filename)
with FileLock(self.lock_path):
# use tempfile to be robust against program interruptions
mode = "wb" if binary else "w"
with open(filepath + ".tmp", mode) as f:
f.write(data)
os.rename(filepath + ".tmp", filepath)
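# Illustrative use (assumed key and filename): all artifacts of one compilation live
# under <cache_dir>/<key>/, and put() writes them atomically (tmp file + os.rename)
# while holding the per-directory FileLock, e.g.:
#     cm = CacheManager(fn_cache_key)
#     cm.put(ptx, "kernel.ptx", binary=False)
#     cm.has_file("kernel.ptx")  # -> True on later runs, so recompilation is skipped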
# utilities for generating and compiling C wrappers
@functools.lru_cache()
def libcuda_dir():
loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
return os.path.dirname(loc)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(name, src, srcdir):
cuda_lib_dir = libcuda_dir()
cu_include_dir = "/usr/local/cuda/include"
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
cc = os.environ.get("CC")
if cc is None:
# TODO: support more things here.
clang = shutil.which("clang")
gcc = shutil.which("gcc")
cc = gcc if gcc is not None else clang
py_include_dir = get_paths()["include"]
ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = [cuda_lib_dir]
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name=name,
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
)
# build extension module
args = ['build_ext']
args.append('--build-temp=' + srcdir)
args.append('--build-lib=' + srcdir)
args.append('-q')
args = dict(
name=name,
ext_modules=[ext],
script_args=args,
)
with quiet():
setuptools.setup(**args)
return so
def make_so_cache_key(signature, constants):
# Get unique key for the compiled code
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
key = f"{''.join(signature.values())}{constants}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
# we get the kernel, i.e. the first function generated in the module
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
# cache manager
name = fn.__name__
# name of files that are cached
so_cache_key = make_so_cache_key(signature, constants)
so_cache_manager = CacheManager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
if not so_cache_manager.has_file(so_name):
with tempfile.TemporaryDirectory() as tmpdir:
src = generate_launcher(name, constants, signature)
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(fn.__name__, src_path, tmpdir)
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
# retrieve cached shared object if it exists
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
data_name = f"{name}.json"
if not fn_cache_manager.has_file(cubin_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name):
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(cubin, cubin_name)
fn_cache_manager.put(ptx, ptx_name, binary=False)
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
class CompiledKernel:
def __init__(self, fn_name, so_path, cache_dir):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
metadata = json.load(f)
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = dict()
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
device = torch.cuda.current_device()
mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
def __getitem__(self, grid):
def runner(*args, stream=None):
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
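To close the excerpt, a brief assumed usage note for `CompiledKernel` (the names `binary`, `grid_x`, `kernel_args`, and `stream_handle` are hypothetical): the metadata loaded from `{name}.json` and the cached cubin/PTX are exposed directly, and launches default to the current CUDA stream.

# `binary` is a hypothetical CompiledKernel returned by compile(...)
print(binary.shared, binary.num_warps, binary.num_stages)   # metadata from {name}.json
ptx_text = binary.asm["ptx"]        # human-readable PTX cached as {name}.ptx
cubin_bytes = binary.asm["cubin"]   # image loaded into cu_module via _triton.load_binary
binary[(grid_x, 1, 1)](*kernel_args)                        # uses torch.cuda.current_stream()
binary[(grid_x, 1, 1)](*kernel_args, stream=stream_handle)  # or an explicit CUstream handle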