Manage insertion block with context manager
This commit is contained in:
@@ -22,6 +22,22 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
class enter_sub_region:
    """Context manager that scopes code generation for a nested region.

    On entry it snapshots the generator's symbol tables and the builder's
    current insertion block, then gives the nested region an empty
    ``local_defs``. On exit it moves the builder back to the end of the saved
    insertion block and reinstalls the snapshots, so definitions made inside
    the region do not leak into the parent scope.

    Usage::

        with enter_sub_region(generator) as (liveins, insert_block):
            ...
    """

    # The annotation is quoted because ``CodeGenerator`` is defined further
    # down in the module; an unquoted name would be evaluated when this
    # ``def`` executes and fail unless lazy annotations are enabled.
    def __init__(self, generator: "CodeGenerator"):
        """Remember the generator whose scope and insertion point are managed.

        ``generator`` must expose ``lscope``, ``local_defs`` and ``builder``.
        """
        self.generator = generator

    def __enter__(self):
        """Snapshot the parent scope and start the region with no local defs.

        Returns a ``(liveins, insert_block)`` tuple: a copy of the parent's
        ``lscope`` taken at entry, and the builder's insertion block at entry.
        """
        gen = self.generator
        # Record lscope & local_defs of the parent scope (shallow copies).
        self.liveins = gen.lscope.copy()
        self.prev_defs = gen.local_defs.copy()
        # The nested region starts without any definitions of its own.
        gen.local_defs = {}
        self.insert_block = gen.builder.get_insertion_block()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        """Restore the parent's insertion point and symbol tables.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        gen = self.generator
        gen.builder.set_insertion_point_to_end(self.insert_block)
        # NOTE: the dict handed out by __enter__ becomes ``lscope`` again, so
        # any mutation the caller made to ``liveins`` is kept.
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
@@ -267,12 +283,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
parent_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
|
||||
ip_block = self.builder.get_insertion_block()
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
@@ -322,10 +334,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
# restore values in the parent scope
|
||||
self.lscope = liveins
|
||||
self.local_defs = parent_defs
|
||||
# update values yielded by IfOp
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||
@@ -396,11 +404,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
# condition (the before region)
|
||||
cond_block = self.builder.create_block()
|
||||
@@ -452,9 +457,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
@@ -508,18 +510,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
|
||||
# cache current insertion block
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
# create loop body block
|
||||
block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(block)
|
||||
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
|
||||
# visit loop body
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
@@ -530,7 +527,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name])
|
||||
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
names.append(name)
|
||||
@@ -550,10 +547,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
Reference in New Issue
Block a user