From 0f96da336a77eb77561a177403ecc0353aa0c6c8 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Mon, 4 Apr 2022 12:58:37 +0800 Subject: [PATCH] codegen for If --- python/triton/code_gen.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index cf1594757..e65a21b5c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -326,13 +326,18 @@ class CodeGenerator(ast.NodeVisitor): if then_defs or node.orelse: if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) - then_yield_op = if_op.get_then_yield() - else_yield_op = if_op.get_else_yield() - for name in names: - then_yield_op.append_operand(then_defs[name].handle) - else_yield_op.append_operand(else_defs[name].handle) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([y.handle for n, y in then_defs.items()]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([y.handle for n, y in else_defs.items()]) else: if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) + then_block.merge_block_before(if_op.get_then_block()) self.builder.set_insertion_point_to_end(ip_block) # restore values in the parent scope @@ -344,7 +349,6 @@ class CodeGenerator(ast.NodeVisitor): self.lscope[name] = new_tensor self.local_defs[name] = new_tensor - else: if isinstance(cond, triton.language.constexpr): cond = cond.value