diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 4187c76f1..a4d90f96a 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1248,7 +1248,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tu if warm_cache_only: return # load_binary() requires a valid cuda context - return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device) + return CompiledKernel(fn, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device) class CompiledKernel: @@ -1257,11 +1257,12 @@ class CompiledKernel: launch_enter_hook = None launch_exit_hook = None - def __init__(self, fn_name, so_path, cache_dir, device): + def __init__(self, fn, so_path, cache_dir, device): # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location("launcher", so_path) mod = importlib.util.module_from_spec(spec) + fn_name = fn.__name__ spec.loader.exec_module(mod) self.c_wrapper = getattr(mod, "launch") # initialize metadata @@ -1283,7 +1284,7 @@ class CompiledKernel: self.asm["ttir"] = f.read() mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) - self.fn_name = fn_name + self.fn = fn self.cu_module = mod self.cu_function = func self.n_regs = n_regs