Pass function arguments to JITFunction's call_hook

This commit is contained in:
Da Yan
2023-01-09 20:53:20 +00:00
parent 0f5c6e619c
commit d9392c0fd3

View File

@@ -198,7 +198,7 @@ class JITFunction(KernelInterface[T]):
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
@@ -217,7 +217,7 @@ class JITFunction(KernelInterface[T]):
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False, args=args, arg_names=self.arg_names)
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
@@ -277,7 +277,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)