Manage insertion block with context manager

This commit is contained in:
Yan Da
2022-04-10 15:02:12 +08:00
parent aa6e086881
commit 9c7b3d5173

View File

@@ -22,6 +22,22 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
def __enter__(self):
# record lscope & local_defs in the parent scope
self.liveins = self.generator.lscope.copy()
self.prev_defs = self.generator.local_defs.copy()
self.generator.local_defs = {}
self.insert_block = self.generator.builder.get_insertion_block()
return self.liveins, self.insert_block
def __exit__(self, *args, **kwargs):
self.generator.builder.set_insertion_point_to_end(self.insert_block)
self.generator.lscope = self.liveins
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
@@ -267,12 +283,8 @@ class CodeGenerator(ast.NodeVisitor):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
cond = cond.to(triton.language.int1, _builder=self.builder)
# record lscope & local_defs in the parent scope
liveins = self.lscope.copy()
parent_defs = self.local_defs.copy()
self.local_defs = {}
ip_block = self.builder.get_insertion_block()
with enter_sub_region(self) as sr:
liveins, ip_block = sr
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
@@ -322,10 +334,6 @@ class CodeGenerator(ast.NodeVisitor):
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(ip_block)
# restore values in the parent scope
self.lscope = liveins
self.local_defs = parent_defs
# update values yielded by IfOp
for i, name in enumerate(names):
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
@@ -396,11 +404,8 @@ class CodeGenerator(ast.NodeVisitor):
return getattr(op, fn)()
def visit_While(self, node):
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
self.local_defs = {}
insert_block = self.builder.get_insertion_block()
with enter_sub_region(self) as sr:
liveins, insert_block = sr
# condtion (the before region)
cond_block = self.builder.create_block()
@@ -452,9 +457,6 @@ class CodeGenerator(ast.NodeVisitor):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
self.builder.set_insertion_point_to_end(insert_block)
self.lscope = liveins
self.local_defs = prev_defs
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
@@ -508,18 +510,13 @@ class CodeGenerator(ast.NodeVisitor):
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# cache current insertion block
insert_block = self.builder.get_insertion_block()
with enter_sub_region(self) as sr:
liveins, insert_block = sr
# create loop body block
block = self.builder.create_block()
self.builder.set_insertion_point_to_start(block)
# record lscope & local_defs in the parent scope
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
self.local_defs = {}
# visit loop body
self.visit_compound_statement(node.body)
@@ -530,7 +527,7 @@ class CodeGenerator(ast.NodeVisitor):
names = []
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name])
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
names.append(name)
@@ -550,10 +547,6 @@ class CodeGenerator(ast.NodeVisitor):
# arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
self.builder.set_insertion_point_to_end(insert_block)
self.lscope = liveins
self.local_defs = prev_defs
# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)