fix issues in visit_If

This commit is contained in:
Yan Da
2022-04-10 16:28:45 +08:00
parent fcbbb3c10e
commit 4eb062f313
2 changed files with 12 additions and 7 deletions

View File

@@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy() then_defs = self.local_defs.copy()
# when need an else block when:
# 1. we have an orelse node
# or
# 2. the then block defines new variable
if then_defs or node.orelse: if then_defs or node.orelse:
if node.orelse: if node.orelse:
self.lscope = liveins self.lscope = liveins
@@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(ip_block) self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse: if then_defs or node.orelse: # with else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_block.merge_block_before(if_op.get_then_block()) then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block()) self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([y.handle for n, y in then_defs.items()]) self.builder.create_yield_op([then_defs[n].handle for n in names])
if not node.orelse: if not node.orelse:
else_block = if_op.get_else_block() else_block = if_op.get_else_block()
else: else:
else_block.merge_block_before(if_op.get_else_block()) else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block()) self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([y.handle for n, y in else_defs.items()]) self.builder.create_yield_op([else_defs[n].handle for n in names])
else: else: # no else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block()) then_block.merge_block_before(if_op.get_then_block())
@@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor):
yields = [] yields = []
for name in loop_defs: for name in loop_defs:
if name in liveins: if name in liveins:
# We should not def new constexpr (?) # We should not def new constexpr
assert self.is_triton_tensor(loop_defs[name]) assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name]) assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type: if loop_defs[name].type == liveins[name].type:

View File

@@ -26,12 +26,13 @@ def nested_cf(X, lb, ub, Z):
if lb < ub: if lb < ub:
for z in range(0, Z): for z in range(0, Z):
a += 2.0 a += 2.0
# a += 2.0
else: else:
# a *= 2.0
while a < 1.2: while a < 1.2:
a *= 2.0 a *= 2.0
for _ in range(0, Z, 2):
a *= -3.3
a -= 1.0 a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
assert mod.verify(), mod.str() assert mod.verify(), mod.str()
mod.dump()