[RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742)
This commit is contained in:
@@ -826,7 +826,7 @@ def make_triton_ir(fn, signature, specialization, constants):
|
|||||||
gscope = fn.__globals__.copy()
|
gscope = fn.__globals__.copy()
|
||||||
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
|
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
|
||||||
tys = list(signature.values())
|
tys = list(signature.values())
|
||||||
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
|
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
|
||||||
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
|
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
|
||||||
all_constants = constants.copy()
|
all_constants = constants.copy()
|
||||||
all_constants.update(new_constants)
|
all_constants.update(new_constants)
|
||||||
|
@@ -259,12 +259,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
# build dict of constant values
|
# build dict of constant values
|
||||||
args = [{args}]
|
args = [{args}]
|
||||||
configs = self._get_config(*args),
|
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||||
|
configs = self._get_config(*all_args),
|
||||||
constants = self._make_constants(constexpr_key)
|
constants = self._make_constants(constexpr_key)
|
||||||
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
|
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||||
# build kernel signature -- doesn't include specialized arguments
|
# build kernel signature -- doesn't include specialized arguments
|
||||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
|
||||||
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
||||||
# build stub signature -- includes arguments that are specialized
|
# build stub signature -- includes arguments that are specialized
|
||||||
for i, arg in constants.items():
|
for i, arg in constants.items():
|
||||||
|
Reference in New Issue
Block a user