diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 5bdeb4347..3296883b1 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -826,7 +826,7 @@ def make_triton_ir(fn, signature, specialization, constants): gscope = fn.__globals__.copy() function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) tys = list(signature.values()) - new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1} + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1} new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} all_constants = constants.copy() all_constants.update(new_constants) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4e9e58b06..3cf9f836e 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -259,12 +259,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage except KeyError: # build dict of constant values args = [{args}] - configs = self._get_config(*args), + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + configs = self._get_config(*all_args), constants = self._make_constants(constexpr_key) - constants.update({{i: None for i, arg in enumerate(args) if arg is None}}) + constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}}) constants.update({{i: 1 for i in configs[0].equal_to_1}}) # build kernel signature -- doesn't include specialized arguments - all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} # build stub signature -- includes arguments that are specialized for i, arg in constants.items():