fix issues in visit_If

This commit is contained in:
Yan Da
2022-04-10 16:28:45 +08:00
parent fcbbb3c10e
commit 4eb062f313
2 changed files with 12 additions and 7 deletions

View File

@@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy()
# when need an else block when:
# 1. we have an orelse node
# or
# 2. the then block defines new variable
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
@@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse:
if then_defs or node.orelse: # with else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
self.builder.create_yield_op([then_defs[n].handle for n in names])
if not node.orelse:
else_block = if_op.get_else_block()
else:
else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
else:
self.builder.create_yield_op([else_defs[n].handle for n in names])
else: # no else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block())
@@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor):
yields = []
for name in loop_defs:
if name in liveins:
# We should not def new constexpr (?)
# We should not def new constexpr
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type: