[RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742)

This commit is contained in:
Philippe Tillet
2022-10-05 11:00:32 -07:00
committed by GitHub
parent 77c752dc78
commit bdfdb9a1d2
2 changed files with 4 additions and 4 deletions

View File

@@ -826,7 +826,7 @@ def make_triton_ir(fn, signature, specialization, constants):
gscope = fn.__globals__.copy()
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
tys = list(signature.values())
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
all_constants = constants.copy()
all_constants.update(new_constants)

View File

@@ -259,12 +259,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
except KeyError:
# build dict of constant values
args = [{args}]
configs = self._get_config(*args),
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():