Manage insertion block with context manager
This commit is contained in:
@@ -22,6 +22,22 @@ import triton
|
|||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
class enter_sub_region:
|
||||||
|
def __init__(self, generator: CodeGenerator):
|
||||||
|
self.generator = generator
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# record lscope & local_defs in the parent scope
|
||||||
|
self.liveins = self.generator.lscope.copy()
|
||||||
|
self.prev_defs = self.generator.local_defs.copy()
|
||||||
|
self.generator.local_defs = {}
|
||||||
|
self.insert_block = self.generator.builder.get_insertion_block()
|
||||||
|
return self.liveins, self.insert_block
|
||||||
|
|
||||||
|
def __exit__(self, *args, **kwargs):
|
||||||
|
self.generator.builder.set_insertion_point_to_end(self.insert_block)
|
||||||
|
self.generator.lscope = self.liveins
|
||||||
|
self.generator.local_defs = self.prev_defs
|
||||||
|
|
||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||||
@@ -267,65 +283,57 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = self.visit(node.test)
|
cond = self.visit(node.test)
|
||||||
if isinstance(cond, triton.language.tensor):
|
if isinstance(cond, triton.language.tensor):
|
||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
# record lscope & local_defs in the parent scope
|
with enter_sub_region(self) as sr:
|
||||||
liveins = self.lscope.copy()
|
liveins, ip_block = sr
|
||||||
parent_defs = self.local_defs.copy()
|
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
ip_block = self.builder.get_insertion_block()
|
then_block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_start(then_block)
|
||||||
|
self.visit_compound_statement(node.body)
|
||||||
|
then_defs = self.local_defs.copy()
|
||||||
|
|
||||||
then_block = self.builder.create_block()
|
if then_defs or node.orelse:
|
||||||
self.builder.set_insertion_point_to_start(then_block)
|
if node.orelse:
|
||||||
self.visit_compound_statement(node.body)
|
self.local_defs = {}
|
||||||
then_defs = self.local_defs.copy()
|
else_block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_end(else_block)
|
||||||
|
self.visit_compound_statement(node.orelse)
|
||||||
|
else_defs = self.local_defs.copy()
|
||||||
|
else:
|
||||||
|
# collect else_defs
|
||||||
|
else_defs = {}
|
||||||
|
for name in then_defs:
|
||||||
|
if name in liveins:
|
||||||
|
# TODO: what if this is constexpr?
|
||||||
|
assert self.is_triton_tensor(then_defs[name])
|
||||||
|
assert self.is_triton_tensor(liveins[name])
|
||||||
|
else_defs[name] = liveins[name]
|
||||||
|
# collect yields
|
||||||
|
names = []
|
||||||
|
ret_types = []
|
||||||
|
for then_name in then_defs:
|
||||||
|
for else_name in else_defs:
|
||||||
|
if then_name == else_name:
|
||||||
|
if then_defs[then_name].type == else_defs[else_name].type:
|
||||||
|
names.append(then_name)
|
||||||
|
ret_types.append(then_defs[then_name].type)
|
||||||
|
|
||||||
if then_defs or node.orelse:
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
if node.orelse:
|
|
||||||
self.local_defs = {}
|
if then_defs or node.orelse:
|
||||||
else_block = self.builder.create_block()
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||||
self.builder.set_insertion_point_to_end(else_block)
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
self.visit_compound_statement(node.orelse)
|
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||||
else_defs = self.local_defs.copy()
|
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
||||||
|
if not node.orelse:
|
||||||
|
else_block = if_op.get_else_block()
|
||||||
|
else:
|
||||||
|
else_block.merge_block_before(if_op.get_else_block())
|
||||||
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||||
|
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
||||||
else:
|
else:
|
||||||
# collect else_defs
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
else_defs = {}
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
for name in then_defs:
|
|
||||||
if name in liveins:
|
|
||||||
# TODO: what if this is constexpr?
|
|
||||||
assert self.is_triton_tensor(then_defs[name])
|
|
||||||
assert self.is_triton_tensor(liveins[name])
|
|
||||||
else_defs[name] = liveins[name]
|
|
||||||
# collect yields
|
|
||||||
names = []
|
|
||||||
ret_types = []
|
|
||||||
for then_name in then_defs:
|
|
||||||
for else_name in else_defs:
|
|
||||||
if then_name == else_name:
|
|
||||||
if then_defs[then_name].type == else_defs[else_name].type:
|
|
||||||
names.append(then_name)
|
|
||||||
ret_types.append(then_defs[then_name].type)
|
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
|
||||||
|
|
||||||
if then_defs or node.orelse:
|
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
|
||||||
self.builder.create_yield_op([y.handle for n, y in then_defs.items()])
|
|
||||||
if not node.orelse:
|
|
||||||
else_block = if_op.get_else_block()
|
|
||||||
else:
|
|
||||||
else_block.merge_block_before(if_op.get_else_block())
|
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
|
||||||
self.builder.create_yield_op([y.handle for n, y in else_defs.items()])
|
|
||||||
else:
|
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
|
||||||
# restore values in the parent scope
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = parent_defs
|
|
||||||
# update values yielded by IfOp
|
# update values yielded by IfOp
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||||
@@ -396,70 +404,64 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return getattr(op, fn)()
|
return getattr(op, fn)()
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
liveins = self.lscope.copy()
|
with enter_sub_region(self) as sr:
|
||||||
prev_defs = self.local_defs.copy()
|
liveins, insert_block = sr
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
insert_block = self.builder.get_insertion_block()
|
# condtion (the before region)
|
||||||
|
cond_block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_start(cond_block)
|
||||||
|
cond = self.visit(node.test)
|
||||||
|
|
||||||
# condtion (the before region)
|
# loop body (the after region)
|
||||||
cond_block = self.builder.create_block()
|
loop_block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(cond_block)
|
self.builder.set_insertion_point_to_start(loop_block)
|
||||||
cond = self.visit(node.test)
|
self.visit_compound_statement(node.body)
|
||||||
|
loop_defs = self.local_defs
|
||||||
|
|
||||||
# loop body (the after region)
|
# collect loop-carried values
|
||||||
loop_block = self.builder.create_block()
|
names = []
|
||||||
self.builder.set_insertion_point_to_start(loop_block)
|
ret_types = []
|
||||||
self.visit_compound_statement(node.body)
|
init_args = []
|
||||||
loop_defs = self.local_defs
|
yields = []
|
||||||
|
for name in loop_defs:
|
||||||
|
if name in liveins:
|
||||||
|
# We should not def new constexpr (?)
|
||||||
|
assert self.is_triton_tensor(loop_defs[name])
|
||||||
|
assert self.is_triton_tensor(liveins[name])
|
||||||
|
if loop_defs[name].type == liveins[name].type:
|
||||||
|
# these are loop-carried values
|
||||||
|
names.append(name)
|
||||||
|
ret_types.append(loop_defs[name].type)
|
||||||
|
init_args.append(liveins[name])
|
||||||
|
yields.append(loop_defs[name])
|
||||||
|
|
||||||
# collect loop-carried values
|
self.builder.set_insertion_point_to_end(insert_block)
|
||||||
names = []
|
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||||
ret_types = []
|
[arg.handle for arg in init_args])
|
||||||
init_args = []
|
# merge the condition region
|
||||||
yields = []
|
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||||
for name in loop_defs:
|
[ty.to_ir(self.builder) for ty in ret_types])
|
||||||
if name in liveins:
|
cond_block.merge_block_before(before_block)
|
||||||
# We should not def new constexpr (?)
|
self.builder.set_insertion_point_to_end(before_block)
|
||||||
assert self.is_triton_tensor(loop_defs[name])
|
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||||
assert self.is_triton_tensor(liveins[name])
|
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
||||||
if loop_defs[name].type == liveins[name].type:
|
# merge the loop body
|
||||||
# these are loop-carried values
|
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
||||||
names.append(name)
|
[ty.to_ir(self.builder) for ty in ret_types])
|
||||||
ret_types.append(loop_defs[name].type)
|
loop_block.merge_block_before(after_block)
|
||||||
init_args.append(liveins[name])
|
self.builder.set_insertion_point_to_end(after_block)
|
||||||
yields.append(loop_defs[name])
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
# update global uses in while_op
|
||||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
for i, name in enumerate(names):
|
||||||
[arg.handle for arg in init_args])
|
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||||
# merge the condition region
|
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
|
||||||
[ty.to_ir(self.builder) for ty in ret_types])
|
|
||||||
cond_block.merge_block_before(before_block)
|
|
||||||
self.builder.set_insertion_point_to_end(before_block)
|
|
||||||
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
|
||||||
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
|
||||||
# merge the loop body
|
|
||||||
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
|
||||||
[ty.to_ir(self.builder) for ty in ret_types])
|
|
||||||
loop_block.merge_block_before(after_block)
|
|
||||||
self.builder.set_insertion_point_to_end(after_block)
|
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
|
||||||
|
|
||||||
# update global uses in while_op
|
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
self.lscope[name] = new_def
|
||||||
|
self.local_defs[name] = new_def
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = prev_defs
|
|
||||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
|
||||||
for i, name in enumerate(names):
|
|
||||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
|
||||||
self.lscope[name] = new_def
|
|
||||||
self.local_defs[name] = new_def
|
|
||||||
|
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
assert False, "Not implemented"
|
assert False, "Not implemented"
|
||||||
@@ -508,52 +510,43 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ub = self.builder.create_to_index(ub)
|
ub = self.builder.create_to_index(ub)
|
||||||
step = self.builder.create_to_index(step)
|
step = self.builder.create_to_index(step)
|
||||||
|
|
||||||
# cache current insertion block
|
with enter_sub_region(self) as sr:
|
||||||
insert_block = self.builder.get_insertion_block()
|
liveins, insert_block = sr
|
||||||
|
|
||||||
# create loop body block
|
# create loop body block
|
||||||
block = self.builder.create_block()
|
block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(block)
|
self.builder.set_insertion_point_to_start(block)
|
||||||
|
|
||||||
# record lscope & local_defs in the parent scope
|
# visit loop body
|
||||||
liveins = self.lscope.copy()
|
self.visit_compound_statement(node.body)
|
||||||
prev_defs = self.local_defs.copy()
|
|
||||||
self.local_defs = {}
|
|
||||||
|
|
||||||
# visit loop body
|
# If a variable (name) is defined in both its parent & itself, then it's
|
||||||
self.visit_compound_statement(node.body)
|
# a loop-carried variable. (They must be of the same type)
|
||||||
|
init_args = []
|
||||||
|
yields = []
|
||||||
|
names = []
|
||||||
|
for name in self.local_defs:
|
||||||
|
if name in liveins:
|
||||||
|
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||||
|
assert self.is_triton_tensor(liveins[name])
|
||||||
|
if self.local_defs[name].type == liveins[name].type:
|
||||||
|
names.append(name)
|
||||||
|
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||||
|
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||||
|
# create ForOp
|
||||||
|
self.builder.set_insertion_point_to_end(insert_block)
|
||||||
|
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||||
|
block.merge_block_before(for_op.get_body(0))
|
||||||
|
# create YieldOp
|
||||||
|
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||||
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
for_op_region = for_op.get_body(0).get_parent()
|
||||||
|
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||||
|
# replace global uses with block arguments
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
# arg0 is the induction variable
|
||||||
|
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||||
|
|
||||||
# If a variable (name) is defined in both its parent & itself, then it's
|
|
||||||
# a loop-carried variable. (They must be of the same type)
|
|
||||||
init_args = []
|
|
||||||
yields = []
|
|
||||||
names = []
|
|
||||||
for name in self.local_defs:
|
|
||||||
if name in liveins:
|
|
||||||
assert self.is_triton_tensor(self.local_defs[name])
|
|
||||||
assert self.is_triton_tensor(liveins[name])
|
|
||||||
if self.local_defs[name].type == liveins[name].type:
|
|
||||||
names.append(name)
|
|
||||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
|
||||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
|
||||||
# create ForOp
|
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
|
||||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
|
||||||
block.merge_block_before(for_op.get_body(0))
|
|
||||||
# create YieldOp
|
|
||||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
|
||||||
for_op_region = for_op.get_body(0).get_parent()
|
|
||||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
|
||||||
# replace global uses with block arguments
|
|
||||||
for i, name in enumerate(names):
|
|
||||||
# arg0 is the induction variable
|
|
||||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
|
||||||
|
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
|
||||||
|
|
||||||
self.lscope = liveins
|
|
||||||
self.local_defs = prev_defs
|
|
||||||
# update lscope & local_defs (ForOp defines new values)
|
# update lscope & local_defs (ForOp defines new values)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||||
|
Reference in New Issue
Block a user