codegen for If
This commit is contained in:
@@ -326,13 +326,18 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
if then_defs or node.orelse:
|
if then_defs or node.orelse:
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||||
then_yield_op = if_op.get_then_yield()
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
else_yield_op = if_op.get_else_yield()
|
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||||
for name in names:
|
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||||
then_yield_op.append_operand(then_defs[name].handle)
|
if not node.orelse:
|
||||||
else_yield_op.append_operand(else_defs[name].handle)
|
else_block = if_op.get_else_block()
|
||||||
|
else:
|
||||||
|
else_block.merge_block_before(if_op.get_else_block())
|
||||||
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||||
|
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||||
else:
|
else:
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
# restore values in the parent scope
|
# restore values in the parent scope
|
||||||
@@ -344,7 +349,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.lscope[name] = new_tensor
|
self.lscope[name] = new_tensor
|
||||||
self.local_defs[name] = new_tensor
|
self.local_defs[name] = new_tensor
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if isinstance(cond, triton.language.constexpr):
|
if isinstance(cond, triton.language.constexpr):
|
||||||
cond = cond.value
|
cond = cond.value
|
||||||
|
Reference in New Issue
Block a user