diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index a7664a896..bb781dfe6 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) then_defs = self.local_defs.copy() + # we need an else block when: + # 1. we have an orelse node + # or + # 2. the then block defines a new variable if then_defs or node.orelse: if node.orelse: self.lscope = liveins @@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor): self.builder.set_insertion_point_to_end(ip_block) - if then_defs or node.orelse: + if then_defs or node.orelse: # with else block if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) then_block.merge_block_before(if_op.get_then_block()) self.builder.set_insertion_point_to_end(if_op.get_then_block()) - self.builder.create_yield_op([y.handle for n, y in then_defs.items()]) + self.builder.create_yield_op([then_defs[n].handle for n in names]) if not node.orelse: else_block = if_op.get_else_block() else: else_block.merge_block_before(if_op.get_else_block()) self.builder.set_insertion_point_to_end(if_op.get_else_block()) - self.builder.create_yield_op([y.handle for n, y in else_defs.items()]) - else: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + else: # no else block if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) then_block.merge_block_before(if_op.get_then_block()) @@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor): yields = [] for name in loop_defs: if name in liveins: - # We should not def new constexpr (?) 
+ # We should not def new constexpr assert self.is_triton_tensor(loop_defs[name]) assert self.is_triton_tensor(liveins[name]) if loop_defs[name].type == liveins[name].type: diff --git a/rewrite-test/jit/while.py b/rewrite-test/jit/while.py index 5e27025af..2c114cd4b 100644 --- a/rewrite-test/jit/while.py +++ b/rewrite-test/jit/while.py @@ -26,12 +26,13 @@ def nested_cf(X, lb, ub, Z): if lb < ub: for z in range(0, Z): a += 2.0 - # a += 2.0 else: - # a *= 2.0 while a < 1.2: a *= 2.0 + for _ in range(0, Z, 2): + a *= -3.3 a -= 1.0 mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) assert mod.verify(), mod.str() +mod.dump()