diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 3296883b1..9c74cd896 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -959,7 +959,7 @@ def generate_launcher(identifier, constants, signature): "int64_t": "L", }[ty] - format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) # generate glue code src = f""" @@ -1019,14 +1019,37 @@ static PyObject* launch(PyObject* self, PyObject* args) {{ uint64_t _function; int num_warps; int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + PyObject *hook_ret = NULL; {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ return NULL; }} + if (launch_enter_hook != Py_None) {{ + PyObject *new_args = PyTuple_Pack(1, compiled_kernel); + hook_ret = PyObject_CallObject(launch_enter_hook, new_args); + Py_DECREF(new_args); + }} + _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + if (launch_exit_hook != Py_None) {{ + PyObject *new_args = NULL; + if (hook_ret) {{ + new_args = PyTuple_Pack(2, compiled_kernel, hook_ret); + }} else {{ + new_args = PyTuple_Pack(1, compiled_kernel); + }} + hook_ret = PyObject_CallObject(launch_exit_hook, new_args); + Py_DECREF(new_args); + }} + if (hook_ret) {{ + Py_DECREF(hook_ret); + }} if(PyErr_Occurred()) {{ return NULL; }} @@ -1242,6 +1265,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i class CompiledKernel: + # Hooks for external tools to monitor the execution of triton kernels + launch_enter_hook = None + launch_exit_hook = None + def __init__(self, fn_name, so_path, cache_dir, device): # initialize launcher import importlib.util @@ -1267,6 +1294,7 @@ class CompiledKernel: self.asm["ttir"] = f.read() mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) + self.fn_name = fn_name self.cu_module = mod self.cu_function = func self.n_regs = n_regs @@ -1276,7 +1304,8 @@ class CompiledKernel: def runner(*args, stream=None): if stream is None: stream = torch.cuda.current_stream().cuda_stream - self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args) return runner def get_sass(self, fun=None): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 3cf9f836e..bfa0edcc9 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -109,6 +109,7 @@ class KernelInterface: class JITFunction(KernelInterface): + # Hook for inspecting compiled functions and modules cache_hook = None divisibility = 16 @@ -253,7 +254,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage try: bin = cache[key] if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args}) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}) return bin # kernel not cached -- compile except KeyError: @@ -274,7 +275,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args) self.cache[key] = bin return bin return None