From 4ad432f1fc851ab1e39d2b47726f24380ab874c9 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 31 Mar 2022 21:42:48 +0800 Subject: [PATCH] More on scf Ops --- python/src/triton.cc | 40 ++++++---- python/triton/code_gen.py | 156 +++++++++++++++++++++++--------------- 2 files changed, 119 insertions(+), 77 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 201f122e1..8baffd066 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -711,6 +711,17 @@ void init_triton_ir(py::module &&m) { .def("dump", [](mlir::FuncOp &self) { self.dump(); }) ; + // Ops + py::class_(m, "ForOp") + .def("get_body", &mlir::scf::ForOp::getBody, ret::reference); + py::class_(m, "IfOp") + .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &mlir::scf::IfOp::thenYield) + .def("get_else_yield", &mlir::scf::IfOp::elseYield) + ; + py::class_(m, "YieldOp"); + py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) // // getters @@ -839,21 +850,20 @@ void init_triton_ir(py::module &&m) { } throw std::runtime_error("invalid function type"); }) - // // Structured control flow - // .def("create_for", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub, - // mlir::Value &step, std::vector &initArgs) -> MlirOperation { - // auto loc = self.getUnknownLoc(); - // return wrap(self.create( - // loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation()); - // }) - // .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation { - // auto loc = self.getUnknownLoc(); - // return wrap(self.create(loc, unwrap(condition)).getOperation()); - // }) - // .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation { - // auto loc = self.getUnknownLoc(); - // return wrap(self.create(loc).getOperation()); - // }) + // Structured control flow + .def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub, + mlir::Value &step, std::vector &initArgs) -> mlir::scf::ForOp { + auto loc = self.getUnknownLoc(); + return self.create(loc, lb, ub, step); + }) + .def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp { + auto loc = self.getUnknownLoc(); + return self.create(loc, condition); + }) + .def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp { + auto loc = self.getUnknownLoc(); + return self.create(loc); + }) // // .def("create_while") // miscellious diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 06b19ce85..02ba697e7 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -362,30 +362,40 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): - cond = cond.to(triton.language.int1, _builder=self.builder) - current_bb = self.builder.get_insertion_block() - then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) - else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None - endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self._seal_block(then_bb) - if else_bb: - self._seal_block(else_bb) - self.builder.cond_br(cond.handle, then_bb, else_bb) - else: - self.builder.cond_br(cond.handle, then_bb, endif_bb) - self.builder.set_insert_block(then_bb) - is_terminator = self.visit_compound_statement(node.body) - # TODO: last statement is a terminator? - if not is_terminator: - self.builder.br(endif_bb) - if else_bb: - self.builder.set_insert_block(else_bb) - is_terminator = self.visit_compound_statement(node.orelse) - # TODO: last statement is a terminator? - if not is_terminator: - self.builder.br(endif_bb) - self._seal_block(endif_bb) - self.builder.set_insert_block(endif_bb) + # cond = cond.to(triton.language.int1, _builder=self.builder) + # current_bb = self.builder.get_insertion_block() + # then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) + # else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None + # endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) + # self._seal_block(then_bb) + # if else_bb: + # self._seal_block(else_bb) + # self.builder.cond_br(cond.handle, then_bb, else_bb) + # else: + # self.builder.cond_br(cond.handle, then_bb, endif_bb) + # self.builder.set_insert_block(then_bb) + # is_terminator = self.visit_compound_statement(node.body) + # # TODO: last statement is a terminator? + # if not is_terminator: + # self.builder.br(endif_bb) + # if else_bb: + # self.builder.set_insert_block(else_bb) + # is_terminator = self.visit_compound_statement(node.orelse) + # # TODO: last statement is a terminator? + # if not is_terminator: + # self.builder.br(endif_bb) + # self._seal_block(endif_bb) + # self.builder.set_insert_block(endif_bb) + parent_lvalues = self.lvalues.copy() + self.visit_compound_statement(node.body) + then_lvalues = self.lvalues.copy() + assert node.orelse + self.lvalues = parent_lvalues + self.visit_compound_statement(node.orelse) + else_lvalues = self.lvalues.copy() + + self.lvalues = join_if_lvalues(then_lvalues, else_lvalues) + else: if isinstance(cond, triton.language.constexpr): cond = cond.value @@ -500,47 +510,69 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) return - # create nodes - st_target = ast.Name(id=node.target.id, ctx=ast.Store()) - ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) - arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) - arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] - arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) - init_node = ast.Assign(targets=[st_target], value=arg_0) - pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) - neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) - pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) - build_cond = lambda: triton.language.where(self.visit(pos_step_node), - self.visit(pos_cond_node), - self.visit(neg_cond_node), - _builder=self.builder) - # cond_node = neg_cond_node - step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) - # code generation - current_bb = self.builder.get_insertion_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + # # create nodes + # st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + # ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) + # arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) + # arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] + # arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) + # init_node = ast.Assign(targets=[st_target], value=arg_0) + # pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) + # neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) + # pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) + # build_cond = lambda: triton.language.where(self.visit(pos_step_node), + # self.visit(pos_cond_node), + # self.visit(neg_cond_node), + # _builder=self.builder) + # # cond_node = neg_cond_node + # step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) + # # code generation + # current_bb = self.builder.get_insertion_block() + # loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + # next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) - def continue_fn(): - self.visit(step_node) - cond = build_cond() - return self.builder.cond_br(cond.handle, loop_bb, next_bb) + # def continue_fn(): + # self.visit(step_node) + # cond = build_cond() + # return self.builder.cond_br(cond.handle, loop_bb, next_bb) - self.visit(init_node) - cond = build_cond() - self.builder.cond_br(cond.handle, loop_bb, next_bb) - self.builder.set_insert_block(loop_bb) - self.visit_compound_statement(node.body) - # TODO: handle case where body breaks control flow - continue_fn() - stop_bb = self.builder.get_insertion_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) - self.builder.set_insert_block(next_bb) + # self.visit(init_node) + # cond = build_cond() + # self.builder.cond_br(cond.handle, loop_bb, next_bb) + # self.builder.set_insert_block(loop_bb) + # self.visit_compound_statement(node.body) + # # TODO: handle case where body breaks control flow + # continue_fn() + # stop_bb = self.builder.get_insertion_block() + # self._seal_block(stop_bb) + # self._seal_block(loop_bb) + # self._seal_block(next_bb) + # self.builder.set_insert_block(next_bb) - for stmt in node.orelse: - ast.NodeVisitor.generic_visit(self, stmt) + # for stmt in node.orelse: + # ast.NodeVisitor.generic_visit(self, stmt) + lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) + ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) + step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) + + loop_body = self.builder.create_block() + self.builder.set_insertion_point_to_start(loop_body) + + prev_defs = self.local_defs.copy() + self.local_defs = set() + + # visit loop body + parent_lvalues = self.lvalues.copy() + self.visit_compound_statement() + loop_lvalues = self.lvalues.copy() + + # TODO: update insertion point + # TODO: create scf.forOp + # self.lvalues = join_loop_lvalues(parent_lvalues, loop_lvalues) + # self.make_for_op(parent_lvalues, loop_lvalues, lb, ub, step) + for_op = self.builder.create_for_op(lb, ub, step, [loop_init_args]) + + self.local_defs = prev_defs def visit_Slice(self, node): lower = self.visit(node.lower)