fix issues in visit_If
This commit is contained in:
@@ -291,6 +291,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
then_defs = self.local_defs.copy()
|
then_defs = self.local_defs.copy()
|
||||||
|
|
||||||
|
# when need an else block when:
|
||||||
|
# 1. we have an orelse node
|
||||||
|
# or
|
||||||
|
# 2. the then block defines new variable
|
||||||
if then_defs or node.orelse:
|
if then_defs or node.orelse:
|
||||||
if node.orelse:
|
if node.orelse:
|
||||||
self.lscope = liveins
|
self.lscope = liveins
|
||||||
@@ -319,18 +323,18 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
|
|
||||||
if then_defs or node.orelse:
|
if then_defs or node.orelse: # with else block
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||||
if not node.orelse:
|
if not node.orelse:
|
||||||
else_block = if_op.get_else_block()
|
else_block = if_op.get_else_block()
|
||||||
else:
|
else:
|
||||||
else_block.merge_block_before(if_op.get_else_block())
|
else_block.merge_block_before(if_op.get_else_block())
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||||
else:
|
else: # no else block
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
|
|
||||||
@@ -425,7 +429,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
yields = []
|
yields = []
|
||||||
for name in loop_defs:
|
for name in loop_defs:
|
||||||
if name in liveins:
|
if name in liveins:
|
||||||
# We should not def new constexpr (?)
|
# We should not def new constexpr
|
||||||
assert self.is_triton_tensor(loop_defs[name])
|
assert self.is_triton_tensor(loop_defs[name])
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
if loop_defs[name].type == liveins[name].type:
|
if loop_defs[name].type == liveins[name].type:
|
||||||
|
@@ -26,12 +26,13 @@ def nested_cf(X, lb, ub, Z):
|
|||||||
if lb < ub:
|
if lb < ub:
|
||||||
for z in range(0, Z):
|
for z in range(0, Z):
|
||||||
a += 2.0
|
a += 2.0
|
||||||
# a += 2.0
|
|
||||||
else:
|
else:
|
||||||
# a *= 2.0
|
|
||||||
while a < 1.2:
|
while a < 1.2:
|
||||||
a *= 2.0
|
a *= 2.0
|
||||||
|
for _ in range(0, Z, 2):
|
||||||
|
a *= -3.3
|
||||||
a -= 1.0
|
a -= 1.0
|
||||||
|
|
||||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||||
assert mod.verify(), mod.str()
|
assert mod.verify(), mod.str()
|
||||||
|
mod.dump()
|
||||||
|
Reference in New Issue
Block a user