From baab18e1d1e5ca949cc855d1dcb6fe4bb59d8ece Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Sun, 23 Oct 2022 20:32:25 -0700
Subject: [PATCH] Improve

Pass the kernel's constexpr values through compile() and the generated
launcher, and simplify the launch hooks: both hooks now receive the
launcher's full argument tuple instead of a specially packed one, and
the enter hook's return value is no longer forwarded to the exit hook.
---
 python/triton/compiler.py    | 29 +++++++++--------------------
 python/triton/runtime/jit.py |  6 +++---
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 1332f2c76..4187c76f1 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -959,7 +959,7 @@ def generate_launcher(identifier, constants, signature):
         "int64_t": "L",
     }[ty]
 
-    format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
+    format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + "O"
 
     # generate glue code
     src = f"""
@@ -1022,34 +1022,22 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *compiled_kernel = NULL;
-  PyObject *hook_ret = NULL;
+  PyObject *constants = NULL;
   {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
-  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())}, &constants)) {{
     return NULL;
   }}
 
   if (launch_enter_hook != Py_None) {{
-    PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
-    hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
-    Py_DECREF(new_args);
+    PyObject_CallObject(launch_enter_hook, args);
  }}
 
   _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
 
   if (launch_exit_hook != Py_None) {{
-    PyObject *new_args = NULL;
-    if (hook_ret) {{
-      new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
-    }} else {{
-      new_args = PyTuple_Pack(1, compiled_kernel);
-    }}
-    hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
-    Py_DECREF(new_args);
+    PyObject_CallObject(launch_exit_hook, args);
   }}
 
-  if (hook_ret) {{
-    Py_DECREF(hook_ret);
-  }}
   if(PyErr_Occurred()) {{
     return NULL;
   }}
@@ -1214,7 +1202,7 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key
 
 
-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
+def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tuple(), num_warps: int = 4,
             num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
@@ -1250,7 +1238,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
             not fn_cache_manager.has_file(llir_name):
         asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                             extern_libs, "cubin", cc)
-        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
+        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages, "constexpr": constexpr}
         fn_cache_manager.put(asm["cubin"], cubin_name)
         fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
         fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
@@ -1282,6 +1270,7 @@ class CompiledKernel:
         self.shared = metadata["shared"]
         self.num_warps = metadata["num_warps"]
         self.num_stages = metadata["num_stages"]
+        self.constexpr = metadata["constexpr"]
         # initialize asm dict
         self.asm = dict()
         with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
@@ -1305,7 +1294,7 @@ class CompiledKernel:
             if stream is None:
                 stream = torch.cuda.current_stream().cuda_stream
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
-                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
+                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args, self.constexpr)
         return runner
 
     def get_sass(self, fun=None):
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 89ad3e2ca..f5a720edb 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -254,7 +254,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
     try:
         bin = cache[device][key]
         if not warmup:
-            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
+            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}, constexpr_key)
         return bin
     # kernel not cached -- compile
     except KeyError:
@@ -272,9 +272,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         if callable(arg):
             raise TypeError(f"Callable constexpr at index {i} is not supported")
         if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
-            bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
+            bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
             if not warmup:
-                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
+                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args, constexpr_key)
             self.cache[device][key] = bin
             return bin
     return None
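
Illustrative note on the new hook calling convention (not part of the
patch itself): the generated launcher now forwards the entire argument
tuple it received to both launch_enter_hook and launch_exit_hook, the
enter hook's return value is no longer threaded into the exit hook, and
the constexpr values arrive as the final argument (the trailing "O" in
the parse format string). Below is a minimal, hypothetical hook written
against that convention; it assumes only the CompiledKernel class
attributes that the diff's call sites already reference, and it relies
on the fact that hook return values are now discarded by the launcher.

    import triton

    def launch_hook(gridX, gridY, gridZ, num_warps, shared_memory,
                    stream, function, enter_hook, exit_hook,
                    compiled_kernel, *rest):
        # Positional order mirrors PyArg_ParseTuple in the generated
        # launcher: grid and launch metadata first, then the kernel
        # arguments, with the constexpr values appended last by the
        # call sites updated in this patch.
        *kernel_args, constexprs = rest
        print(f"kernel launch: grid=({gridX}, {gridY}, {gridZ}), "
              f"num_warps={num_warps}, shared={shared_memory}, "
              f"constexprs={constexprs}")

    # Hooks are installed as class attributes, matching the call sites
    # in the diff.
    triton.compiler.CompiledKernel.launch_enter_hook = launch_hook
    triton.compiler.CompiledKernel.launch_exit_hook = launch_hook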