[RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742)
This commit is contained in:
@@ -826,7 +826,7 @@ def make_triton_ir(fn, signature, specialization, constants):
|
||||
gscope = fn.__globals__.copy()
|
||||
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
|
||||
tys = list(signature.values())
|
||||
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
|
||||
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
|
||||
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
|
||||
all_constants = constants.copy()
|
||||
all_constants.update(new_constants)
|
||||
|
@@ -259,12 +259,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
except KeyError:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
configs = self._get_config(*args),
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
|
Reference in New Issue
Block a user