More progress on WhileOp codegen

This commit is contained in:
Yan Da
2022-04-05 15:55:48 +08:00
parent 76d9249724
commit c7ad928e60
5 changed files with 145 additions and 45 deletions

View File

@@ -46,6 +46,7 @@ class CodeGenerator(ast.NodeVisitor):
# SSA-construction
# name => triton.language.tensor
self.local_defs: Dict[str, triton.language.tensor] = {}
self.global_uses: Dict[str, triton.language.tensor] = {}
def get_value(self, name):
''' This function:
@@ -57,6 +58,8 @@ class CodeGenerator(ast.NodeVisitor):
ret = None
if name in self.lscope:
ret = self.lscope[name]
if name not in self.local_defs:
self.global_uses[name] = ret
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
@@ -263,27 +266,6 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
# cond = cond.to(triton.language.int1, _builder=self.builder)
# current_bb = self.builder.get_insertion_block()
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
# if else_bb:
# self.builder.cond_br(cond.handle, then_bb, else_bb)
# else:
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
# self.builder.set_insert_block(then_bb)
# is_terminator = self.visit_compound_statement(node.body)
# # TODO: last statement is a terminator?
# if not is_terminator:
# self.builder.br(endif_bb)
# if else_bb:
# self.builder.set_insert_block(else_bb)
# is_terminator = self.visit_compound_statement(node.orelse)
# # TODO: last statement is a terminator?
# if not is_terminator:
# self.builder.br(endif_bb)
# self.builder.set_insert_block(endif_bb)
cond = cond.to(triton.language.int1, _builder=self.builder)
liveins = self.lscope.copy()
parent_defs = self.local_defs.copy()
@@ -413,22 +395,64 @@ class CodeGenerator(ast.NodeVisitor):
return getattr(op, fn)()
def visit_While(self, node):
current_bb = self.builder.get_insertion_block()
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
self.local_defs = {}
def continue_fn():
cond = self.visit(node.test)
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
insert_block = self.builder.get_insertion_block()
continue_fn()
self.builder.set_insert_block(loop_bb)
# condtion (the before region)
cond_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(cond_block)
cond = self.visit(node.test)
# loop body (the after region)
loop_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_block)
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insertion_block()
self.builder.set_insert_block(next_bb)
loop_defs = self.local_defs
# collect loop-carried values
names = []
ret_types = []
init_args = []
yields = []
for name in loop_defs:
if name in liveins:
# We should not def new constexpr (?)
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type:
# these are loop-carried values
names.append(name)
ret_types.append(loop_defs[name].type.to_ir(self.builder))
init_args.append(liveins[name])
yields.append(loop_defs[name])
self.builder.set_insertion_point_to_end(insert_block)
while_op = self.builder.create_while_op(ret_types, init_args)
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before())
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
self.builder.create_condtion_op(cond.handle, [])
# merge the loop body
after_block = self.builder.create_block_with_parent(while_op.get_after())
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
self.builder.set_insertion_point_to_end(insert_block)
self.lscope = liveins
self.local_defs = prev_defs
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
for stmt in node.orelse:
assert False, "Not implemented"
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript(self, node):