diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index ab1a5516f..40b32ca34 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -198,7 +198,7 @@ class JITFunction(KernelInterface[T]): constants = {i: k for i, k in zip(self.constexprs, constexpr_key)} return constants - def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args): if JITFunction.cache_hook is None: return False name = self.fn.__name__ @@ -217,7 +217,7 @@ class JITFunction(KernelInterface[T]): num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs) - return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False, args=args, arg_names=self.arg_names) def _make_launcher(self): regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] @@ -277,7 +277,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {{i}} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args): bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs) if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)