[RUNTIME] Add callback functions for external tools (#738)

This commit is contained in:
Keren Zhou
2022-10-05 14:46:55 -07:00
committed by GitHub
parent bdfdb9a1d2
commit 11345e9b74
2 changed files with 35 additions and 5 deletions

View File

@@ -959,7 +959,7 @@ def generate_launcher(identifier, constants, signature):
"int64_t": "L", "int64_t": "L",
}[ty] }[ty]
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code # generate glue code
src = f""" src = f"""
@@ -1019,14 +1019,37 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
uint64_t _function; uint64_t _function;
int num_warps; int num_warps;
int shared_memory; int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
PyObject *hook_ret = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL; return NULL;
}} }}
if (launch_enter_hook != Py_None) {{
PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
Py_DECREF(new_args);
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject *new_args = NULL;
if (hook_ret) {{
new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
}} else {{
new_args = PyTuple_Pack(1, compiled_kernel);
}}
hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
Py_DECREF(new_args);
}}
if (hook_ret) {{
Py_DECREF(hook_ret);
}}
if(PyErr_Occurred()) {{ if(PyErr_Occurred()) {{
return NULL; return NULL;
}} }}
@@ -1242,6 +1265,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
class CompiledKernel: class CompiledKernel:
# Hooks for external tools to monitor the execution of triton kernels
launch_enter_hook = None
launch_exit_hook = None
def __init__(self, fn_name, so_path, cache_dir, device): def __init__(self, fn_name, so_path, cache_dir, device):
# initialize launcher # initialize launcher
import importlib.util import importlib.util
@@ -1267,6 +1294,7 @@ class CompiledKernel:
self.asm["ttir"] = f.read() self.asm["ttir"] = f.read()
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.fn_name = fn_name
self.cu_module = mod self.cu_module = mod
self.cu_function = func self.cu_function = func
self.n_regs = n_regs self.n_regs = n_regs
@@ -1276,7 +1304,8 @@ class CompiledKernel:
def runner(*args, stream=None): def runner(*args, stream=None):
if stream is None: if stream is None:
stream = torch.cuda.current_stream().cuda_stream stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
return runner return runner
def get_sass(self, fun=None): def get_sass(self, fun=None):

View File

@@ -109,6 +109,7 @@ class KernelInterface:
class JITFunction(KernelInterface): class JITFunction(KernelInterface):
# Hook for inspecting compiled functions and modules
cache_hook = None cache_hook = None
divisibility = 16 divisibility = 16
@@ -253,7 +254,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
try: try:
bin = cache[key] bin = cache[key]
if not warmup: if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args}) bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
return bin return bin
# kernel not cached -- compile # kernel not cached -- compile
except KeyError: except KeyError:
@@ -274,7 +275,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
if not warmup: if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args) bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
self.cache[key] = bin self.cache[key] = bin
return bin return bin
return None return None