4 Commits

Author SHA1 Message Date
Jokeren
a601309d87 Merge branch 'master' into keren/improve-hook 2023-01-04 17:40:10 -05:00
Jokeren
ee098d0341 Merge branch 'master' into keren/improve-hook 2022-11-25 15:04:59 -08:00
Jokeren
feef58ee8a Pass fn to CompiliedKernel 2022-11-24 14:22:35 -08:00
Jokeren
baab18e1d1 Improve 2022-10-23 20:32:25 -07:00
2 changed files with 12 additions and 22 deletions

View File

@@ -1077,7 +1077,7 @@ def generate_launcher(constants, signature):
"int64_t": "L",
}[ty]
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + "O"
# generate glue code
src = f"""
@@ -1138,34 +1138,22 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
PyObject *hook_ret = NULL;
PyObject *constants = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())}, &constants)) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
Py_DECREF(new_args);
PyObject_CallObject(launch_enter_hook, args);
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject *new_args = NULL;
if (hook_ret) {{
new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
}} else {{
new_args = PyTuple_Pack(1, compiled_kernel);
}}
hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
Py_DECREF(new_args);
PyObject_CallObject(launch_exit_hook, args);
}}
if (hook_ret) {{
Py_DECREF(hook_ret);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
@@ -1523,7 +1511,7 @@ def compile(fn, **kwargs):
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
# return handle to compiled kernel
return CompiledKernel(so_path, metadata, asm)
return CompiledKernel(fn, so_path, metadata, asm)
class CompiledKernel:
@@ -1532,17 +1520,19 @@ class CompiledKernel:
launch_enter_hook = None
launch_exit_hook = None
def __init__(self, so_path, metadata, asm):
def __init__(self, fn, so_path, metadata, asm):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
mod = importlib.util.module_from_spec(spec)
self.fn = fn
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
self.num_stages = metadata["num_stages"]
self.constexpr = metadata["constexpr"]
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
@@ -1577,7 +1567,7 @@ class CompiledKernel:
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args, self.constexpr)
return runner
def get_sass(self, fun=None):

View File

@@ -260,7 +260,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
try:
bin = cache[device][key]
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}, constexpr_key)
return bin
# kernel not cached -- compile
except KeyError:
@@ -280,7 +280,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args, constexpr_key)
self.cache[device][key] = bin
return bin
return None