[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-12-02 17:45:29 +08:00
committed by GitHub
parent c280ebda1b
commit 521ff9ad74
2 changed files with 56 additions and 3 deletions

View File

@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
cond = cond.to(triton.language.int1, _builder=self.builder)
with enter_sub_region(self) as sr:
liveins, ip_block = sr
liveins_copy = liveins.copy()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
@@ -394,7 +394,15 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
# defined in else block but not in then block
# to find in parent scope and yield them
for else_name in else_defs:
if else_name in liveins and else_name not in then_defs:
if else_defs[else_name].type == liveins[else_name].type:
names.append(else_name)
ret_types.append(else_defs[else_name].type)
then_defs[else_name] = liveins_copy[else_name]
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse: # with else block