Pass fn to CompiliedKernel
This commit is contained in:
@@ -1248,7 +1248,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tu
|
|||||||
if warm_cache_only:
|
if warm_cache_only:
|
||||||
return # load_binary() requires a valid cuda context
|
return # load_binary() requires a valid cuda context
|
||||||
|
|
||||||
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
|
return CompiledKernel(fn, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
|
||||||
|
|
||||||
|
|
||||||
class CompiledKernel:
|
class CompiledKernel:
|
||||||
@@ -1257,11 +1257,12 @@ class CompiledKernel:
|
|||||||
launch_enter_hook = None
|
launch_enter_hook = None
|
||||||
launch_exit_hook = None
|
launch_exit_hook = None
|
||||||
|
|
||||||
def __init__(self, fn_name, so_path, cache_dir, device):
|
def __init__(self, fn, so_path, cache_dir, device):
|
||||||
# initialize launcher
|
# initialize launcher
|
||||||
import importlib.util
|
import importlib.util
|
||||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||||
mod = importlib.util.module_from_spec(spec)
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
fn_name = fn.__name__
|
||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod)
|
||||||
self.c_wrapper = getattr(mod, "launch")
|
self.c_wrapper = getattr(mod, "launch")
|
||||||
# initialize metadata
|
# initialize metadata
|
||||||
@@ -1283,7 +1284,7 @@ class CompiledKernel:
|
|||||||
self.asm["ttir"] = f.read()
|
self.asm["ttir"] = f.read()
|
||||||
|
|
||||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||||
self.fn_name = fn_name
|
self.fn = fn
|
||||||
self.cu_module = mod
|
self.cu_module = mod
|
||||||
self.cu_function = func
|
self.cu_function = func
|
||||||
self.n_regs = n_regs
|
self.n_regs = n_regs
|
||||||
|
Reference in New Issue
Block a user