codegen for If

This commit is contained in:
Yan Da
2022-04-04 12:58:37 +08:00
parent 9df899b291
commit 0f96da336a

View File

@@ -326,13 +326,18 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs or node.orelse:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_yield_op = if_op.get_then_yield()
else_yield_op = if_op.get_else_yield()
for name in names:
then_yield_op.append_operand(then_defs[name].handle)
else_yield_op.append_operand(else_defs[name].handle)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
if not node.orelse:
else_block = if_op.get_else_block()
else:
else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
else:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(ip_block)
# restore values in the parent scope
@@ -344,7 +349,6 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope[name] = new_tensor
self.local_defs[name] = new_tensor
else:
if isinstance(cond, triton.language.constexpr):
cond = cond.value