Manage insertion block with context manager
This commit is contained in:
@@ -22,6 +22,22 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
class enter_sub_region:
    """Context manager that brackets code generation for a nested region.

    On entry it snapshots the generator's symbol tables and the builder's
    current insertion block, then gives the region an empty ``local_defs``.
    On exit it moves the builder back to the end of the recorded block and
    reinstates the snapshots, so callers no longer have to save and restore
    this state by hand.

    Usage::

        with enter_sub_region(generator) as sr:
            liveins, insert_block = sr
            ...
    """

    # The annotation is a string (forward reference) because CodeGenerator is
    # defined further down in this module; a bare name would be evaluated when
    # this class body runs and raise NameError unless annotations are deferred.
    def __init__(self, generator: "CodeGenerator"):
        """Remember the generator whose scope and insertion point are managed."""
        self.generator = generator

    def __enter__(self):
        """Save the parent scope and start the region with no local defs.

        Returns:
            A ``(liveins, insert_block)`` tuple: a copy of the generator's
            ``lscope`` taken at entry, and the builder's insertion block at
            entry.
        """
        # record lscope & local_defs in the parent scope
        self.liveins = self.generator.lscope.copy()
        self.prev_defs = self.generator.local_defs.copy()
        # definitions made inside the region are collected separately
        self.generator.local_defs = {}
        self.insert_block = self.generator.builder.get_insertion_block()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        """Restore the insertion point and the parent's symbol tables.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed; the restore still runs before it propagates.
        """
        self.generator.builder.set_insertion_point_to_end(self.insert_block)
        # ``lscope`` becomes the snapshot taken in __enter__ (the same object
        # that was handed to the caller as ``liveins``).
        self.generator.lscope = self.liveins
        self.generator.local_defs = self.prev_defs
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
@@ -267,65 +283,57 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
parent_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
ip_block = self.builder.get_insertion_block()
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
then_defs = self.local_defs.copy()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
then_defs = self.local_defs.copy()
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_defs = self.local_defs.copy()
|
||||
else:
|
||||
# collect else_defs
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
# TODO: what if this is constexpr?
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
# collect yields
|
||||
names = []
|
||||
ret_types = []
|
||||
for then_name in then_defs:
|
||||
for else_name in else_defs:
|
||||
if then_name == else_name:
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_defs = self.local_defs.copy()
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||
else:
|
||||
# collect else_defs
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
# TODO: what if this is constexpr?
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
# collect yields
|
||||
names = []
|
||||
ret_types = []
|
||||
for then_name in then_defs:
|
||||
for else_name in else_defs:
|
||||
if then_name == else_name:
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||
else:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
# restore values in the parent scope
|
||||
self.lscope = liveins
|
||||
self.local_defs = parent_defs
|
||||
# update values yielded by IfOp
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||
@@ -396,70 +404,64 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
# condition (the before region)
|
||||
cond_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(cond_block)
|
||||
cond = self.visit(node.test)
|
||||
|
||||
# condition (the before region)
|
||||
cond_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(cond_block)
|
||||
cond = self.visit(node.test)
|
||||
# loop body (the after region)
|
||||
loop_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
loop_defs = self.local_defs
|
||||
|
||||
# loop body (the after region)
|
||||
loop_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
loop_defs = self.local_defs
|
||||
# collect loop-carried values
|
||||
names = []
|
||||
ret_types = []
|
||||
init_args = []
|
||||
yields = []
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr (?)
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type)
|
||||
init_args.append(liveins[name])
|
||||
yields.append(loop_defs[name])
|
||||
|
||||
# collect loop-carried values
|
||||
names = []
|
||||
ret_types = []
|
||||
init_args = []
|
||||
yields = []
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr (?)
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type)
|
||||
init_args.append(liveins[name])
|
||||
yields.append(loop_defs[name])
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
||||
# merge the loop body
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
||||
# merge the loop body
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Not implemented"
|
||||
@@ -508,52 +510,43 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
|
||||
# cache current insertion block
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
# create loop body block
|
||||
block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(block)
|
||||
# create loop body block
|
||||
block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(block)
|
||||
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
# visit loop body
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
# visit loop body
|
||||
self.visit_compound_statement(node.body)
|
||||
# If a variable (name) is defined in both its parent & itself, then it's
|
||||
# a loop-carried variable. (They must be of the same type)
|
||||
init_args = []
|
||||
yields = []
|
||||
names = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
names.append(name)
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
# create ForOp
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
# create YieldOp
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
# If a variable (name) is defined in both its parent & itself, then it's
|
||||
# a loop-carried variable. (They must be of the same type)
|
||||
init_args = []
|
||||
yields = []
|
||||
names = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
names.append(name)
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
# create ForOp
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
# create YieldOp
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
Reference in New Issue
Block a user