[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
liveins_copy = liveins.copy()
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -394,7 +394,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
|
||||
# defined in else block but not in then block
|
||||
# to find in parent scope and yield them
|
||||
for else_name in else_defs:
|
||||
if else_name in liveins and else_name not in then_defs:
|
||||
if else_defs[else_name].type == liveins[else_name].type:
|
||||
names.append(else_name)
|
||||
ret_types.append(else_defs[else_name].type)
|
||||
then_defs[else_name] = liveins_copy[else_name]
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse: # with else block
|
||||
|
Reference in New Issue
Block a user