Pass fn to CompiledKernel
This commit is contained in:
@@ -1248,7 +1248,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tu
|
||||
if warm_cache_only:
|
||||
return # load_binary() requires a valid cuda context
|
||||
|
||||
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
|
||||
return CompiledKernel(fn, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
@@ -1257,11 +1257,12 @@ class CompiledKernel:
|
||||
launch_enter_hook = None
|
||||
launch_exit_hook = None
|
||||
|
||||
def __init__(self, fn_name, so_path, cache_dir, device):
|
||||
def __init__(self, fn, so_path, cache_dir, device):
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
fn_name = fn.__name__
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
@@ -1283,7 +1284,7 @@ class CompiledKernel:
|
||||
self.asm["ttir"] = f.read()
|
||||
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.fn_name = fn_name
|
||||
self.fn = fn
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
self.n_regs = n_regs
|
||||
|
Reference in New Issue
Block a user