Pass function arguments to JITFunction's call_hook
This commit is contained in:
@@ -198,7 +198,7 @@ class JITFunction(KernelInterface[T]):
|
|||||||
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
|
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
|
||||||
return constants
|
return constants
|
||||||
|
|
||||||
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args):
|
||||||
if JITFunction.cache_hook is None:
|
if JITFunction.cache_hook is None:
|
||||||
return False
|
return False
|
||||||
name = self.fn.__name__
|
name = self.fn.__name__
|
||||||
@@ -217,7 +217,7 @@ class JITFunction(KernelInterface[T]):
|
|||||||
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
||||||
configs=configs)
|
configs=configs)
|
||||||
|
|
||||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False, args=args, arg_names=self.arg_names)
|
||||||
|
|
||||||
def _make_launcher(self):
|
def _make_launcher(self):
|
||||||
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||||
@@ -277,7 +277,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
for i, arg in constants.items():
|
for i, arg in constants.items():
|
||||||
if callable(arg):
|
if callable(arg):
|
||||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs, args):
|
||||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
|
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||||
|
Reference in New Issue
Block a user