codegen for If
This commit is contained in:
@@ -326,13 +326,18 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_yield_op = if_op.get_then_yield()
|
||||
else_yield_op = if_op.get_else_yield()
|
||||
for name in names:
|
||||
then_yield_op.append_operand(then_defs[name].handle)
|
||||
else_yield_op.append_operand(else_defs[name].handle)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||
else:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
# restore values in the parent scope
|
||||
@@ -344,7 +349,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.lscope[name] = new_tensor
|
||||
self.local_defs[name] = new_tensor
|
||||
|
||||
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
cond = cond.value
|
||||
|
Reference in New Issue
Block a user