/* * Copyright (c) 2015, PHILIPPE TILLET. All rights reserved. * * This file is part of ISAAC. * * ISAAC is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, * MA 02110-1301 USA */ #include #include #include #include #include "isaac/types.h" #include "isaac/array.h" #include "isaac/runtime/profiles.h" #include "isaac/runtime/execute.h" #include "isaac/jit/syntax/expression/expression.h" #include "isaac/jit/syntax/expression/preset.h" namespace isaac { namespace runtime { namespace detail { inline bool is_elementwise(expression_type type) { return type == ELEMENTWISE_1D || type == ELEMENTWISE_2D; } /** @brief Optimizes the given expression tree */ // expression_type optimize(expression_type & tree, size_t idx) // { // expression_tree::node & node = tree[idx]; // if(node.type==COMPOSITE_OPERATOR_TYPE) // { // //Remove useless reshape // if(node.binary_operator.op.type==RESHAPE) // } // } expression_type parse(expression_tree const & tree, breakpoints_t & bp){ return parse(tree, tree.root(), bp); } /** @brief Parses the breakpoints for a given expression tree */ expression_type parse(expression_tree const & tree, size_t idx, breakpoints_t & bp) { expression_tree::node const & node = tree[idx]; if(node.type==COMPOSITE_OPERATOR_TYPE) { size_t lidx = node.binary_operator.lhs; size_t ridx = node.binary_operator.rhs; expression_type ltype = parse(tree, lidx, bp); expression_type rtype = parse(tree, ridx, bp); op_element const & op = node.binary_operator.op; //Reduction if(op.type_family==REDUCE || op.type_family==REDUCE_ROWS || op.type_family==REDUCE_COLUMNS) { if(!is_elementwise(ltype)) bp.push_back({lidx, ltype}); if(!is_elementwise(rtype)) bp.push_back({ridx, rtype}); if(op.type_family==REDUCE) return REDUCE_1D; if(op.type_family==REDUCE_ROWS) return REDUCE_2D_ROWS; if(op.type_family==REDUCE_COLUMNS) return REDUCE_2D_COLS; } //Matrix Product if(op.type_family==GEMM) { if(tree[lidx].type!=DENSE_ARRAY_TYPE) bp.push_back({lidx, ltype}); if(tree[ridx].type!=DENSE_ARRAY_TYPE) bp.push_back({ridx, rtype}); if(op.type==GEMM_NN_TYPE) return GEMM_NN; if(op.type==GEMM_TN_TYPE) return GEMM_TN; if(op.type==GEMM_NT_TYPE) return GEMM_NT; if(op.type==GEMM_TT_TYPE) return GEMM_TT; } //Arithmetic if(op.type_family==UNARY_ARITHMETIC || op.type_family==BINARY_ARITHMETIC) { //Non-elementwise kernels are temporaries when reshaped if(op.type==RESHAPE_TYPE && !is_elementwise(ltype)) bp.push_back({lidx, ltype}); else { //Matrix-Products are temporaries when not assigned for(expression_type type: std::vector{GEMM_NN,GEMM_TN,GEMM_NT,GEMM_TT}) { if(ltype==type) bp.push_back({lidx, ltype}); if(rtype==type && op.type!=ASSIGN_TYPE) bp.push_back({ridx, rtype}); if(rtype==type && op.type==ASSIGN_TYPE) return type; } //Reductions for(expression_type type: std::vector{REDUCE_2D_ROWS, REDUCE_2D_COLS, REDUCE_1D}) { if(ltype==type && ltype==rtype && tree[tree[lidx].binary_operator.lhs].shape == tree[tree[ridx].binary_operator.lhs].shape) return type; if(ltype==type && !is_elementwise(rtype)) bp.push_back({ridx, rtype}); if(!is_elementwise(ltype) && rtype==type) bp.push_back({lidx, ltype}); if((ltype==type && rtype==ELEMENTWISE_1D) || (ltype==ELEMENTWISE_1D && rtype==type)) return type; } } } } if(numgt1(node.shape)<=1) return ELEMENTWISE_1D; else return ELEMENTWISE_2D; } } /** @brief Executes a expression_tree on the given models map*/ void execute(execution_handler const & c, profiles::map_type & profiles) { typedef isaac::array array; expression_tree tree = c.x(); /*----Optimize----*/ // detail::optimize(tree); /*----Process-----*/ driver::Context const & context = tree.context(); size_t rootidx = tree.root(); std::vector > temporaries; expression_type final_type; /*----Matrix Product-----*/ if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree.data(), rootidx)){ final_type = args.type; } /*----Default-----*/ else { expression_tree::node & root = tree[rootidx]; expression_tree::node & lhs = tree[root.binary_operator.lhs], &rhs = tree[root.binary_operator.rhs]; expression_tree::node root_save = root, lhs_save = lhs, rhs_save = rhs; detail::breakpoints_t breakpoints; breakpoints.reserve(16); /*----Parse required temporaries-----*/ final_type = detail::parse(tree, breakpoints); std::set found; breakpoints.erase(std::remove_if(breakpoints.begin(), breakpoints.end(), [&](detail::breakpoints_t::value_type const & x){return !found.insert(x.first).second;}), breakpoints.end()); /*----Compute required temporaries----*/ for(auto current: breakpoints) { expression_tree::node const & node = tree[current.first]; expression_type type = current.second; std::shared_ptr const & profile = profiles[std::make_pair(type, node.dtype)]; //Create temporary std::shared_ptr tmp = std::make_shared(node.shape, node.dtype, context); temporaries.push_back(tmp); //Compute temporary root.binary_operator.op.type = ASSIGN_TYPE; root.shape = node.shape; root.dtype = node.dtype; lhs = expression_tree::node(*tmp); rhs = node; profile->execute(execution_handler(tree, c.execution_options(), c.dispatcher_options(), c.compilation_options())); //Update the expression tree root = root_save; lhs = lhs_save; rhs = rhs_save; tree[current.first] = expression_tree::node(*tmp); } } /*-----Compute final expression-----*/ profiles[std::make_pair(final_type, tree[rootidx].dtype)]->execute(execution_handler(tree, c.execution_options(), c.dispatcher_options(), c.compilation_options())); } void execute(execution_handler const & c) { execute(c, profiles::get(c.execution_options().queue(c.x().context()))); } } }