Pass fn to CompiliedKernel

This commit is contained in:
Jokeren
2022-11-24 14:22:35 -08:00
parent baab18e1d1
commit feef58ee8a

View File

@@ -1248,7 +1248,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tu
if warm_cache_only:
return # load_binary() requires a valid cuda context
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
return CompiledKernel(fn, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
class CompiledKernel:
@@ -1257,11 +1257,12 @@ class CompiledKernel:
launch_enter_hook = None
launch_exit_hook = None
def __init__(self, fn_name, so_path, cache_dir, device):
def __init__(self, fn, so_path, cache_dir, device):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
mod = importlib.util.module_from_spec(spec)
fn_name = fn.__name__
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
@@ -1283,7 +1284,7 @@ class CompiledKernel:
self.asm["ttir"] = f.read()
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.fn_name = fn_name
self.fn = fn
self.cu_module = mod
self.cu_function = func
self.n_regs = n_regs