fix issues in visit_If
This commit is contained in:
@@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
then_defs = self.local_defs.copy()
|
||||
|
||||
# when need an else block when:
|
||||
# 1. we have an orelse node
|
||||
# or
|
||||
# 2. the then block defines new variable
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.lscope = liveins
|
||||
@@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if then_defs or node.orelse: # with else block
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||
else:
|
||||
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||
else: # no else block
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
@@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
yields = []
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr (?)
|
||||
# We should not def new constexpr
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
|
@@ -26,12 +26,13 @@ def nested_cf(X, lb, ub, Z):
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
# a += 2.0
|
||||
else:
|
||||
# a *= 2.0
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
for _ in range(0, Z, 2):
|
||||
a *= -3.3
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
assert mod.verify(), mod.str()
|
||||
mod.dump()
|
||||
|
Reference in New Issue
Block a user