199 lines
7.8 KiB
C++
199 lines
7.8 KiB
C++
/* Copyright 2015-2017 Philippe Tillet
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining
|
|
* a copy of this software and associated documentation files
|
|
* (the "Software"), to deal in the Software without restriction,
|
|
* including without limitation the rights to use, copy, modify, merge,
|
|
* publish, distribute, sublicense, and/or sell copies of the Software,
|
|
* and to permit persons to whom the Software is furnished to do so,
|
|
* subject to the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be
|
|
* included in all copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*/
|
|
|
|
#include <assert.h>
|
|
#include <list>
|
|
#include <vector>
|
|
#include <stdexcept>
|
|
#include "isaac/types.h"
|
|
#include "isaac/array.h"
|
|
#include "isaac/runtime/profiles.h"
|
|
#include "isaac/runtime/execute.h"
|
|
#include "isaac/jit/syntax/expression/expression.h"
|
|
#include "isaac/jit/syntax/expression/preset.h"
|
|
|
|
namespace isaac
|
|
{
|
|
|
|
namespace runtime
|
|
{
|
|
|
|
namespace detail
|
|
{
|
|
|
|
inline bool is_elementwise(expression_type type)
|
|
{ return type == ELEMENTWISE_1D || type == ELEMENTWISE_2D; }
|
|
|
|
/** @brief Optimizes the given expression tree */
|
|
// expression_type optimize(expression_type & tree, size_t idx)
|
|
// {
|
|
// expression_tree::node & node = tree[idx];
|
|
// if(node.type==COMPOSITE_OPERATOR_TYPE)
|
|
// {
|
|
// //Remove useless reshape
|
|
// if(node.binary_operator.op.type==RESHAPE)
|
|
// }
|
|
// }
|
|
|
|
expression_type parse(expression_tree const & tree, breakpoints_t & bp){
|
|
return parse(tree, tree.root(), bp);
|
|
}
|
|
|
|
/** @brief Parses the breakpoints for a given expression tree */
|
|
expression_type parse(expression_tree const & tree, size_t idx, breakpoints_t & bp)
|
|
{
|
|
expression_tree::node const & node = tree[idx];
|
|
if(node.type==COMPOSITE_OPERATOR_TYPE)
|
|
{
|
|
size_t lidx = node.binary_operator.lhs;
|
|
size_t ridx = node.binary_operator.rhs;
|
|
expression_type ltype = parse(tree, lidx, bp);
|
|
expression_type rtype = parse(tree, ridx, bp);
|
|
op_element const & op = node.binary_operator.op;
|
|
//Reduction
|
|
if(op.type_family==REDUCE || op.type_family==REDUCE_ROWS || op.type_family==REDUCE_COLUMNS)
|
|
{
|
|
if(!is_elementwise(ltype)) bp.push_back({lidx, ltype});
|
|
if(!is_elementwise(rtype)) bp.push_back({ridx, rtype});
|
|
if(op.type_family==REDUCE) return REDUCE_1D;
|
|
if(op.type_family==REDUCE_ROWS) return REDUCE_2D_ROWS;
|
|
if(op.type_family==REDUCE_COLUMNS) return REDUCE_2D_COLS;
|
|
}
|
|
//Matrix Product
|
|
if(op.type_family==GEMM)
|
|
{
|
|
if(tree[lidx].type!=DENSE_ARRAY_TYPE) bp.push_back({lidx, ltype});
|
|
if(tree[ridx].type!=DENSE_ARRAY_TYPE) bp.push_back({ridx, rtype});
|
|
if(op.type==GEMM_NN_TYPE) return GEMM_NN;
|
|
if(op.type==GEMM_TN_TYPE) return GEMM_TN;
|
|
if(op.type==GEMM_NT_TYPE) return GEMM_NT;
|
|
if(op.type==GEMM_TT_TYPE) return GEMM_TT;
|
|
}
|
|
//Arithmetic
|
|
if(op.type_family==UNARY_ARITHMETIC || op.type_family==BINARY_ARITHMETIC)
|
|
{
|
|
//Non-elementwise kernels are temporaries when reshaped
|
|
if(op.type==RESHAPE_TYPE && !is_elementwise(ltype))
|
|
bp.push_back({lidx, ltype});
|
|
else
|
|
{
|
|
//Matrix-Products are temporaries when not assigned
|
|
for(expression_type type: std::vector<expression_type>{GEMM_NN,GEMM_TN,GEMM_NT,GEMM_TT})
|
|
{
|
|
if(ltype==type)
|
|
bp.push_back({lidx, ltype});
|
|
if(rtype==type && op.type!=ASSIGN_TYPE)
|
|
bp.push_back({ridx, rtype});
|
|
if(rtype==type && op.type==ASSIGN_TYPE)
|
|
return type;
|
|
}
|
|
//Reductions
|
|
for(expression_type type: std::vector<expression_type>{REDUCE_2D_ROWS, REDUCE_2D_COLS, REDUCE_1D})
|
|
{
|
|
if(ltype==type && ltype==rtype && tree[tree[lidx].binary_operator.lhs].shape == tree[tree[ridx].binary_operator.lhs].shape)
|
|
return type;
|
|
if(ltype==type && !is_elementwise(rtype))
|
|
bp.push_back({ridx, rtype});
|
|
if(!is_elementwise(ltype) && rtype==type)
|
|
bp.push_back({lidx, ltype});
|
|
if((ltype==type && rtype==ELEMENTWISE_1D) || (ltype==ELEMENTWISE_1D && rtype==type))
|
|
return type;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if(numgt1(node.shape)<=1)
|
|
return ELEMENTWISE_1D;
|
|
else
|
|
return ELEMENTWISE_2D;
|
|
}
|
|
}
|
|
|
|
/** @brief Executes a expression_tree on the given models map*/
|
|
void execute(execution_handler const & c, profiles::map_type & profiles)
|
|
{
|
|
typedef isaac::array array;
|
|
/*----Optimize----*/
|
|
// detail::optimize(tree);
|
|
/*----Process-----*/
|
|
expression_tree const & reftree = c.x();
|
|
driver::Context const & context = reftree.context();
|
|
size_t rootidx = reftree.root();
|
|
std::vector<std::shared_ptr<array> > temporaries;
|
|
expression_type final_type;
|
|
/*----Matrix Product-----*/
|
|
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(reftree.data(), rootidx)){
|
|
final_type = args.type;
|
|
}
|
|
/*----Default-----*/
|
|
else
|
|
{
|
|
expression_tree tree = reftree;
|
|
expression_tree::node & root = tree[rootidx];
|
|
expression_tree::node & lhs = tree[root.binary_operator.lhs], &rhs = tree[root.binary_operator.rhs];
|
|
expression_tree::node root_save = root, lhs_save = lhs, rhs_save = rhs;
|
|
|
|
detail::breakpoints_t breakpoints;
|
|
breakpoints.reserve(16);
|
|
/*----Parse required temporaries-----*/
|
|
final_type = detail::parse(tree, breakpoints);
|
|
std::set<size_t> found;
|
|
breakpoints.erase(std::remove_if(breakpoints.begin(), breakpoints.end(), [&](detail::breakpoints_t::value_type const & x){return !found.insert(x.first).second;}), breakpoints.end());
|
|
/*----Compute required temporaries----*/
|
|
for(auto current: breakpoints)
|
|
{
|
|
expression_tree::node const & node = tree[current.first];
|
|
expression_type type = current.second;
|
|
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(type, node.dtype)];
|
|
|
|
//Create temporary
|
|
std::shared_ptr<array> tmp = std::make_shared<array>(node.shape, node.dtype, context);
|
|
temporaries.push_back(tmp);
|
|
|
|
//Compute temporary
|
|
root.binary_operator.op.type = ASSIGN_TYPE;
|
|
root.shape = node.shape;
|
|
root.dtype = node.dtype;
|
|
lhs = expression_tree::node(*tmp);
|
|
rhs = node;
|
|
profile->execute(execution_handler(tree, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
|
//Update the expression tree
|
|
root = root_save;
|
|
lhs = lhs_save;
|
|
rhs = rhs_save;
|
|
tree[current.first] = expression_tree::node(*tmp);
|
|
}
|
|
}
|
|
|
|
/*-----Compute final expression-----*/
|
|
profiles[std::make_pair(final_type, reftree[rootidx].dtype)]->execute(execution_handler(reftree, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
|
}
|
|
|
|
void execute(execution_handler const & c)
|
|
{
|
|
execute(c, profiles::get(c.execution_options().queue(c.x().context())));
|
|
}
|
|
|
|
}
|
|
|
|
}
|