From 9c7b3d51739b3264dc3afc2779dce949073608d5 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sun, 10 Apr 2022 15:02:12 +0800 Subject: [PATCH] Manage insertion block with context manager --- python/triton/code_gen.py | 301 +++++++++++++++++++------------------- 1 file changed, 147 insertions(+), 154 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 916ba19b9..32c7ce7f5 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -22,6 +22,22 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract +class enter_sub_region: + def __init__(self, generator: CodeGenerator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.set_insertion_point_to_end(self.insert_block) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, kwargs): @@ -267,65 +283,57 @@ class CodeGenerator(ast.NodeVisitor): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) - # record lscope & local_defs in the parent scope - liveins = self.lscope.copy() - parent_defs = self.local_defs.copy() - self.local_defs = {} + with enter_sub_region(self) as sr: + liveins, ip_block = sr - ip_block = self.builder.get_insertion_block() + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_defs = self.local_defs.copy() - then_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(then_block) - self.visit_compound_statement(node.body) - then_defs = self.local_defs.copy() + if then_defs or node.orelse: + if node.orelse: + self.local_defs = {} + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(else_block) + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else: + # collect else_defs + else_defs = {} + for name in then_defs: + if name in liveins: + # TODO: what if this is constexpr? + assert self.is_triton_tensor(then_defs[name]) + assert self.is_triton_tensor(liveins[name]) + else_defs[name] = liveins[name] + # collect yields + names = [] + ret_types = [] + for then_name in then_defs: + for else_name in else_defs: + if then_name == else_name: + if then_defs[then_name].type == else_defs[else_name].type: + names.append(then_name) + ret_types.append(then_defs[then_name].type) - if then_defs or node.orelse: - if node.orelse: - self.local_defs = {} - else_block = self.builder.create_block() - self.builder.set_insertion_point_to_end(else_block) - self.visit_compound_statement(node.orelse) - else_defs = self.local_defs.copy() + self.builder.set_insertion_point_to_end(ip_block) + + if then_defs or node.orelse: + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([y.handle for n, y in then_defs.items()]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([y.handle for n, y in else_defs.items()]) else: - # collect else_defs - else_defs = {} - for name in then_defs: - if name in liveins: - # TODO: what if this is constexpr? - assert self.is_triton_tensor(then_defs[name]) - assert self.is_triton_tensor(liveins[name]) - else_defs[name] = liveins[name] - # collect yields - names = [] - ret_types = [] - for then_name in then_defs: - for else_name in else_defs: - if then_name == else_name: - if then_defs[then_name].type == else_defs[else_name].type: - names.append(then_name) - ret_types.append(then_defs[then_name].type) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) + then_block.merge_block_before(if_op.get_then_block()) - self.builder.set_insertion_point_to_end(ip_block) - - if then_defs or node.orelse: - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) - then_block.merge_block_before(if_op.get_then_block()) - self.builder.set_insertion_point_to_end(if_op.get_then_block()) - self.builder.create_yield_op([y.handle for n, y in then_defs.items()]) - if not node.orelse: - else_block = if_op.get_else_block() - else: - else_block.merge_block_before(if_op.get_else_block()) - self.builder.set_insertion_point_to_end(if_op.get_else_block()) - self.builder.create_yield_op([y.handle for n, y in else_defs.items()]) - else: - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) - then_block.merge_block_before(if_op.get_then_block()) - - self.builder.set_insertion_point_to_end(ip_block) - # restore values in the parent scope - self.lscope = liveins - self.local_defs = parent_defs # update values yielded by IfOp for i, name in enumerate(names): new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i]) @@ -396,70 +404,64 @@ class CodeGenerator(ast.NodeVisitor): return getattr(op, fn)() def visit_While(self, node): - liveins = self.lscope.copy() - prev_defs = self.local_defs.copy() - self.local_defs = {} + with enter_sub_region(self) as sr: + liveins, insert_block = sr - insert_block = self.builder.get_insertion_block() + # condtion (the before region) + cond_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(cond_block) + cond = self.visit(node.test) - # condtion (the before region) - cond_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(cond_block) - cond = self.visit(node.test) + # loop body (the after region) + loop_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(loop_block) + self.visit_compound_statement(node.body) + loop_defs = self.local_defs - # loop body (the after region) - loop_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(loop_block) - self.visit_compound_statement(node.body) - loop_defs = self.local_defs + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + yields = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr (?) + assert self.is_triton_tensor(loop_defs[name]) + assert self.is_triton_tensor(liveins[name]) + if loop_defs[name].type == liveins[name].type: + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + yields.append(loop_defs[name]) - # collect loop-carried values - names = [] - ret_types = [] - init_args = [] - yields = [] - for name in loop_defs: - if name in liveins: - # We should not def new constexpr (?) - assert self.is_triton_tensor(loop_defs[name]) - assert self.is_triton_tensor(liveins[name]) - if loop_defs[name].type == liveins[name].type: - # these are loop-carried values - names.append(name) - ret_types.append(loop_defs[name].type) - init_args.append(liveins[name]) - yields.append(loop_defs[name]) + self.builder.set_insertion_point_to_end(insert_block) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + cond_block.merge_block_before(before_block) + self.builder.set_insertion_point_to_end(before_block) + # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + loop_block.merge_block_before(after_block) + self.builder.set_insertion_point_to_end(after_block) + self.builder.create_yield_op([y.handle for y in yields]) - self.builder.set_insertion_point_to_end(insert_block) - while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], - [arg.handle for arg in init_args]) - # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before(), - [ty.to_ir(self.builder) for ty in ret_types]) - cond_block.merge_block_before(before_block) - self.builder.set_insertion_point_to_end(before_block) - # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... - self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) - # merge the loop body - after_block = self.builder.create_block_with_parent(while_op.get_after(), - [ty.to_ir(self.builder) for ty in ret_types]) - loop_block.merge_block_before(after_block) - self.builder.set_insertion_point_to_end(after_block) - self.builder.create_yield_op([y.handle for y in yields]) + # update global uses in while_op + for i, name in enumerate(names): + before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) + after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) - # update global uses in while_op - for i, name in enumerate(names): - before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) - after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) - - self.builder.set_insertion_point_to_end(insert_block) - self.lscope = liveins - self.local_defs = prev_defs - # WhileOp defines new values, update the symbol table (lscope, local_defs) - for i, name in enumerate(names): - new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) - self.lscope[name] = new_def - self.local_defs[name] = new_def + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def for stmt in node.orelse: assert False, "Not implemented" @@ -508,52 +510,43 @@ class CodeGenerator(ast.NodeVisitor): ub = self.builder.create_to_index(ub) step = self.builder.create_to_index(step) - # cache current insertion block - insert_block = self.builder.get_insertion_block() + with enter_sub_region(self) as sr: + liveins, insert_block = sr - # create loop body block - block = self.builder.create_block() - self.builder.set_insertion_point_to_start(block) + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) - # record lscope & local_defs in the parent scope - liveins = self.lscope.copy() - prev_defs = self.local_defs.copy() - self.local_defs = {} + # visit loop body + self.visit_compound_statement(node.body) - # visit loop body - self.visit_compound_statement(node.body) + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert self.is_triton_tensor(liveins[name]) + if self.local_defs[name].type == liveins[name].type: + names.append(name) + init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) + yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) + # create ForOp + self.builder.set_insertion_point_to_end(insert_block) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + block.merge_block_before(for_op.get_body(0)) + # create YieldOp + self.builder.set_insertion_point_to_end(for_op.get_body(0)) + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + # replace global uses with block arguments + for i, name in enumerate(names): + # arg0 is the induction variable + for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) - # If a variable (name) is defined in both its parent & itself, then it's - # a loop-carried variable. (They must be of the same type) - init_args = [] - yields = [] - names = [] - for name in self.local_defs: - if name in liveins: - assert self.is_triton_tensor(self.local_defs[name]) - assert self.is_triton_tensor(liveins[name]) - if self.local_defs[name].type == liveins[name].type: - names.append(name) - init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) - yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) - # create ForOp - self.builder.set_insertion_point_to_end(insert_block) - for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) - block.merge_block_before(for_op.get_body(0)) - # create YieldOp - self.builder.set_insertion_point_to_end(for_op.get_body(0)) - self.builder.create_yield_op([y.handle for y in yields]) - for_op_region = for_op.get_body(0).get_parent() - assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" - # replace global uses with block arguments - for i, name in enumerate(names): - # arg0 is the induction variable - for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) - - self.builder.set_insertion_point_to_end(insert_block) - - self.lscope = liveins - self.local_defs = prev_defs # update lscope & local_defs (ForOp defines new values) for i, name in enumerate(names): self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)