Manage insertion block with context manager
This commit is contained in:
@@ -22,6 +22,22 @@ import triton
|
|||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
class enter_sub_region:
    """Context manager that snapshots and restores a generator's scope state.

    On entry it copies the generator's ``lscope`` and ``local_defs``, resets
    ``local_defs`` to an empty dict, and records the builder's current
    insertion block.  On exit it moves the builder's insertion point to the
    end of the recorded block and puts the saved ``lscope`` / ``local_defs``
    copies back on the generator.

    Usage::

        with enter_sub_region(generator) as sr:
            liveins, insert_block = sr
            ...
    """

    # NOTE: the annotation is a string forward reference because
    # ``CodeGenerator`` is defined further down in this module; a bare name
    # would be evaluated when this ``def`` executes and raise ``NameError``
    # unless postponed annotation evaluation is enabled for the file.
    def __init__(self, generator: "CodeGenerator"):
        """Remember the generator whose scope will be saved and restored."""
        self.generator = generator

    def __enter__(self):
        """Save the parent scope and return ``(liveins, insert_block)``.

        ``liveins`` is a copy of the generator's ``lscope`` taken at entry;
        ``insert_block`` is whatever ``builder.get_insertion_block()``
        returned at entry.
        """
        # record lscope & local_defs in the parent scope
        self.liveins = self.generator.lscope.copy()
        self.prev_defs = self.generator.local_defs.copy()
        # the sub-region starts with no local definitions of its own
        self.generator.local_defs = {}
        self.insert_block = self.generator.builder.get_insertion_block()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        """Restore the insertion point and the parent scope.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        self.generator.builder.set_insertion_point_to_end(self.insert_block)
        # restore values in the parent scope
        self.generator.lscope = self.liveins
        self.generator.local_defs = self.prev_defs
|
||||||
|
|
||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||||
@@ -267,12 +283,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = self.visit(node.test)
|
cond = self.visit(node.test)
|
||||||
if isinstance(cond, triton.language.tensor):
|
if isinstance(cond, triton.language.tensor):
|
||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
# record lscope & local_defs in the parent scope
|
with enter_sub_region(self) as sr:
|
||||||
liveins = self.lscope.copy()
|
liveins, ip_block = sr
|
||||||
parent_defs = self.local_defs.copy()
|
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
ip_block = self.builder.get_insertion_block()
|
|
||||||
|
|
||||||
then_block = self.builder.create_block()
|
then_block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(then_block)
|
self.builder.set_insertion_point_to_start(then_block)
|
||||||
@@ -322,10 +334,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
|
||||||
# restore values in the parent scope
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = parent_defs
|
|
||||||
# update values yielded by IfOp
|
# update values yielded by IfOp
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||||
@@ -396,11 +404,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return getattr(op, fn)()
|
return getattr(op, fn)()
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
liveins = self.lscope.copy()
|
with enter_sub_region(self) as sr:
|
||||||
prev_defs = self.local_defs.copy()
|
liveins, insert_block = sr
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
insert_block = self.builder.get_insertion_block()
|
|
||||||
|
|
||||||
# condtion (the before region)
|
# condtion (the before region)
|
||||||
cond_block = self.builder.create_block()
|
cond_block = self.builder.create_block()
|
||||||
@@ -452,9 +457,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = prev_defs
|
|
||||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||||
@@ -508,18 +510,13 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ub = self.builder.create_to_index(ub)
|
ub = self.builder.create_to_index(ub)
|
||||||
step = self.builder.create_to_index(step)
|
step = self.builder.create_to_index(step)
|
||||||
|
|
||||||
# cache current insertion block
|
with enter_sub_region(self) as sr:
|
||||||
insert_block = self.builder.get_insertion_block()
|
liveins, insert_block = sr
|
||||||
|
|
||||||
# create loop body block
|
# create loop body block
|
||||||
block = self.builder.create_block()
|
block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(block)
|
self.builder.set_insertion_point_to_start(block)
|
||||||
|
|
||||||
# record lscope & local_defs in the parent scope
|
|
||||||
liveins = self.lscope.copy()
|
|
||||||
prev_defs = self.local_defs.copy()
|
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
# visit loop body
|
# visit loop body
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
|
|
||||||
@@ -530,7 +527,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
names = []
|
names = []
|
||||||
for name in self.local_defs:
|
for name in self.local_defs:
|
||||||
if name in liveins:
|
if name in liveins:
|
||||||
assert self.is_triton_tensor(self.local_defs[name])
|
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
if self.local_defs[name].type == liveins[name].type:
|
if self.local_defs[name].type == liveins[name].type:
|
||||||
names.append(name)
|
names.append(name)
|
||||||
@@ -550,10 +547,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# arg0 is the induction variable
|
# arg0 is the induction variable
|
||||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
|
||||||
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = prev_defs
|
|
||||||
# update lscope & local_defs (ForOp defines new values)
|
# update lscope & local_defs (ForOp defines new values)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||||
|
Reference in New Issue
Block a user