fix issues in visit_If

This commit is contained in:
Yan Da
2022-04-10 16:28:45 +08:00
parent fcbbb3c10e
commit 4eb062f313
2 changed files with 12 additions and 7 deletions

View File

@@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy()
# when need an else block when:
# 1. we have an orelse node
# or
# 2. the then block defines new variable
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
@@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse:
if then_defs or node.orelse: # with else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
self.builder.create_yield_op([then_defs[n].handle for n in names])
if not node.orelse:
else_block = if_op.get_else_block()
else:
else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
else:
self.builder.create_yield_op([else_defs[n].handle for n in names])
else: # no else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block())
@@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor):
yields = []
for name in loop_defs:
if name in liveins:
# We should not def new constexpr (?)
# We should not def new constexpr
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type:

View File

@@ -26,12 +26,13 @@ def nested_cf(X, lb, ub, Z):
if lb < ub:
for z in range(0, Z):
a += 2.0
# a += 2.0
else:
# a *= 2.0
while a < 1.2:
a *= 2.0
for _ in range(0, Z, 2):
a *= -3.3
a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
assert mod.verify(), mod.str()
mod.dump()