Some progress on visit_If

This commit is contained in:
Yan Da
2022-04-03 22:34:46 +08:00
parent c71c50cd0c
commit 9df899b291
2 changed files with 64 additions and 10 deletions

View File

@@ -284,15 +284,66 @@ class CodeGenerator(ast.NodeVisitor):
# if not is_terminator:
# self.builder.br(endif_bb)
# self.builder.set_insert_block(endif_bb)
parent_values = self.lscope.copy()
self.visit_compound_statement(node.body)
then_values = self.lvalues.copy()
assert node.orelse
self.lscope = parent_values
self.visit_compound_statement(node.orelse)
else_values = self.lscope.copy()
cond = cond.to(triton.language.int1, _builder=self.builder)
liveins = self.lscope.copy()
parent_defs = self.local_defs.copy()
self.local_defs = {}
ip_block = self.builder.get_insertion_block()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy()
if then_defs or node.orelse:
if node.orelse:
self.local_defs = {}
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_end(else_block)
self.visit_compound_statement(node.orelse)
else_defs = self.local_defs.copy()
else:
# collect else_defs
else_defs = {}
for name in then_defs:
if name in liveins:
# TODO: what if this is constexpr?
assert self.is_triton_tensor(then_defs[name])
assert self.is_triton_tensor(liveins[name])
else_defs[name] = liveins[name]
# collect yields
names = []
ret_types = []
for then_name in then_defs:
for else_name in else_defs:
if then_name == else_name:
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_yield_op = if_op.get_then_yield()
else_yield_op = if_op.get_else_yield()
for name in names:
then_yield_op.append_operand(then_defs[name].handle)
else_yield_op.append_operand(else_defs[name].handle)
else:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
self.builder.set_insertion_point_to_end(ip_block)
# restore values in the parent scope
self.lscope = liveins
self.local_defs = parent_defs
# update values yielded by IfOp
for i, name in enumerate(names):
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
self.lscope[name] = new_tensor
self.local_defs[name] = new_tensor
self.lvalues = join_if_lvalues(then_values, else_values)
else:
if isinstance(cond, triton.language.constexpr):