2015-05-04 21:26:09 -04:00
|
|
|
#include <assert.h>
|
2015-01-12 13:20:53 -05:00
|
|
|
#include <list>
|
|
|
|
#include <vector>
|
2015-01-28 17:08:39 -05:00
|
|
|
#include <stdexcept>
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/types.h"
|
|
|
|
#include "isaac/array.h"
|
2015-08-12 00:46:51 -07:00
|
|
|
#include "isaac/profiles/profiles.h"
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/symbolic/expression.h"
|
2015-06-27 17:55:01 -07:00
|
|
|
#include "isaac/symbolic/preset.h"
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
namespace isaac
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
|
|
|
|
namespace detail
|
|
|
|
{
|
2015-12-19 02:38:32 -05:00
|
|
|
/** @brief List of (expression type, operand) pairs that must be evaluated into temporaries */
using breakpoints_t = std::vector<std::pair<expression_type, tree_node*> >;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
|
2015-06-28 00:03:03 -07:00
|
|
|
/** @brief True iff @p x is one of the four matrix-product expression types (NN, NT, TN, TT) */
inline bool is_mmprod(expression_type x)
{
  switch(x)
  {
    case MATRIX_PRODUCT_NN:
    case MATRIX_PRODUCT_NT:
    case MATRIX_PRODUCT_TN:
    case MATRIX_PRODUCT_TT:
      return true;
    default:
      return false;
  }
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-28 00:03:03 -07:00
|
|
|
/** @brief True iff @p x is a 2D reduction (REDUCE_2D_ROWS or REDUCE_2D_COLS) */
inline bool is_mvprod(expression_type x)
{
  switch(x)
  {
    case REDUCE_2D_ROWS:
    case REDUCE_2D_COLS:
      return true;
    default:
      return false;
  }
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-28 00:03:03 -07:00
|
|
|
/** @brief Decides whether one operand of @p op must be computed into a temporary
 *
 *  @param op          operator whose operand is being examined
 *  @param expression  inferred type of the operand being examined
 *  @param other       inferred type of the sibling operand
 *  @param is_first    true when @p op is the root of the parsed tree; only affects
 *                     the matrix-product family
 *  @return true if @p expression must be materialized before @p op can be evaluated
 */
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
{
  bool const is_reduce_1d = (expression == REDUCE_1D);
  switch(op.type_family)
  {
    case UNARY_TYPE_FAMILY:
    case BINARY_TYPE_FAMILY:
      if(is_mmprod(expression))
        return true;
      //Mixing a row reduction with a column reduction needs a temporary
      if(expression == REDUCE_2D_ROWS)
        return other == REDUCE_2D_COLS;
      if(expression == REDUCE_2D_COLS)
        return other == REDUCE_2D_ROWS;
      return false;

    case VECTOR_DOT_TYPE_FAMILY:
      return is_reduce_1d || is_mvprod(expression);

    case ROWS_DOT_TYPE_FAMILY:
    case COLUMNS_DOT_TYPE_FAMILY:
      return is_reduce_1d || is_mmprod(expression) || is_mvprod(expression);

    case MATRIX_PRODUCT_TYPE_FAMILY:
      //A matrix-product operand only needs a temporary below the root
      return is_reduce_1d || is_mvprod(expression) || (!is_first && is_mmprod(expression));

    default:
      return false;
  }
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-28 00:03:03 -07:00
|
|
|
inline std::pair<bool, bool> has_temporary(op_element op, expression_type left, expression_type right, bool is_first)
|
|
|
|
{
|
|
|
|
bool has_temporary_left = has_temporary_impl(op, left, right, is_first);
|
|
|
|
bool has_temporary_right = has_temporary_impl(op, right, left, is_first);
|
|
|
|
return std::make_pair(has_temporary_left, has_temporary_right);
|
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-28 00:03:03 -07:00
|
|
|
/** @brief Infers the expression_type produced by applying @p op to operands of types @p left and @p right
 *
 *  @param op     operator of the node being typed
 *  @param left   inferred type of the left operand (INVALID_EXPRESSION_TYPE if it has none)
 *  @param right  inferred type of the right operand (INVALID_EXPRESSION_TYPE if it has none)
 *  @return the merged expression_type of the node
 *  @throws std::invalid_argument if the operand types of a binary operator cannot be merged,
 *          or if the operator belongs to an unhandled type family
 */
inline expression_type merge(op_element op, expression_type left, expression_type right)
{
  switch(op.type_family)
  {
    case UNARY_TYPE_FAMILY:
      //A unary op on a matrix-product operand yields ELEMENTWISE_2D; otherwise the operand type propagates
      if(is_mmprod(left))
        return ELEMENTWISE_2D;
      return left;

    case BINARY_TYPE_FAMILY:
      //Rules are tried in order; the first match wins
      if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
      else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
      else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
      else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
      else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
      else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
      else if(right == INVALID_EXPRESSION_TYPE) return left;
      else if(left == INVALID_EXPRESSION_TYPE) return right;
      //A bare `throw;` here would call std::terminate (no exception is active); raise a catchable error instead
      throw std::invalid_argument("isaac::detail::merge: incompatible operand types for binary operator");

    case VECTOR_DOT_TYPE_FAMILY:
      return REDUCE_1D;

    case ROWS_DOT_TYPE_FAMILY:
      return REDUCE_2D_ROWS;

    case COLUMNS_DOT_TYPE_FAMILY:
      return REDUCE_2D_COLS;

    case MATRIX_PRODUCT_TYPE_FAMILY:
      if(op.type==MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN;
      else if(op.type==MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN;
      else if(op.type==MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT;
      else return MATRIX_PRODUCT_TT;

    default:
      throw std::invalid_argument("isaac::detail::merge: unrecognized operator type family");
  }
}
|
|
|
|
|
|
|
|
/** @brief Parses the breakpoints for a given expression tree */
|
2015-12-19 02:55:24 -05:00
|
|
|
static void parse(expression_tree::container_type&array, size_t idx,
|
2015-01-12 13:20:53 -05:00
|
|
|
breakpoints_t & breakpoints,
|
2015-04-29 15:50:57 -04:00
|
|
|
expression_type & final_type,
|
|
|
|
bool is_first = true)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-12-19 02:55:24 -05:00
|
|
|
expression_tree::node & node = array[idx];
|
2015-06-28 00:03:03 -07:00
|
|
|
|
2015-11-25 18:42:25 -05:00
|
|
|
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
|
2015-06-28 00:03:03 -07:00
|
|
|
//Left
|
|
|
|
expression_type type_left = INVALID_EXPRESSION_TYPE;
|
2015-12-19 02:04:39 -05:00
|
|
|
if (node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
2015-06-28 00:03:03 -07:00
|
|
|
parse(array, node.lhs.node_index, breakpoints, type_left, false);
|
|
|
|
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-12-19 02:04:39 -05:00
|
|
|
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
|
2015-12-19 00:20:27 -05:00
|
|
|
type_left = ELEMENTWISE_1D;
|
2015-06-28 00:03:03 -07:00
|
|
|
else
|
2015-12-19 00:20:27 -05:00
|
|
|
type_left = ELEMENTWISE_2D;
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
2015-06-28 00:03:03 -07:00
|
|
|
|
|
|
|
//Right
|
|
|
|
expression_type type_right = INVALID_EXPRESSION_TYPE;
|
2015-12-19 02:04:39 -05:00
|
|
|
if (node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
2015-06-28 00:03:03 -07:00
|
|
|
parse(array, node.rhs.node_index, breakpoints, type_right, false);
|
|
|
|
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-12-19 02:04:39 -05:00
|
|
|
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
|
2015-12-19 00:20:27 -05:00
|
|
|
type_right = ELEMENTWISE_1D;
|
2015-06-28 00:03:03 -07:00
|
|
|
else
|
2015-12-19 00:20:27 -05:00
|
|
|
type_right = ELEMENTWISE_2D;
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
2015-06-28 00:03:03 -07:00
|
|
|
|
|
|
|
final_type = merge(array[idx].op, type_left, type_right);
|
|
|
|
std::pair<bool, bool> tmp = has_temporary(array[idx].op, type_left, type_right, is_first);
|
|
|
|
if(tmp.first)
|
|
|
|
breakpoints.push_back(std::make_pair(type_left, &array[idx].lhs));
|
|
|
|
if(tmp.second)
|
|
|
|
breakpoints.push_back(std::make_pair(type_right, &array[idx].rhs));
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-12-19 02:55:24 -05:00
|
|
|
/** @brief Executes a expression_tree on the given models map*/
|
2015-09-30 15:31:41 -04:00
|
|
|
void execute(execution_handler const & c, profiles::map_type & profiles)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-12-19 02:55:24 -05:00
|
|
|
expression_tree expression = c.x();
|
2015-04-29 15:50:57 -04:00
|
|
|
driver::Context const & context = expression.context();
|
2015-02-05 04:42:57 -05:00
|
|
|
size_t rootidx = expression.root();
|
2015-12-19 02:55:24 -05:00
|
|
|
expression_tree::container_type & tree = const_cast<expression_tree::container_type &>(expression.tree());
|
|
|
|
expression_tree::node root_save = tree[rootidx];
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
//Todo: technically the datatype should be per temporary
|
2015-09-30 15:31:41 -04:00
|
|
|
numeric_type dtype = expression.dtype();
|
2015-11-19 12:37:18 -05:00
|
|
|
std::vector<std::shared_ptr<array> > temporaries_;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-27 17:55:01 -07:00
|
|
|
expression_type final_type;
|
2015-12-19 02:04:39 -05:00
|
|
|
//MATRIX_PRODUCT
|
2015-12-12 18:32:06 -05:00
|
|
|
if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){
|
2015-06-27 17:55:01 -07:00
|
|
|
final_type = args.type;
|
|
|
|
}
|
|
|
|
//Default
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
|
|
|
{
|
2015-06-27 17:55:01 -07:00
|
|
|
detail::breakpoints_t breakpoints;
|
|
|
|
breakpoints.reserve(8);
|
|
|
|
|
|
|
|
//Init
|
|
|
|
expression_type current_type;
|
2015-11-25 18:42:25 -05:00
|
|
|
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
|
|
|
|
if(ng1(expression.shape())<=1)
|
2015-12-19 00:20:27 -05:00
|
|
|
current_type=ELEMENTWISE_1D;
|
2015-06-27 17:55:01 -07:00
|
|
|
else
|
2015-12-19 00:20:27 -05:00
|
|
|
current_type=ELEMENTWISE_2D;
|
2015-06-27 17:55:01 -07:00
|
|
|
final_type = current_type;
|
|
|
|
|
|
|
|
/*----Parse required temporaries-----*/
|
2015-06-28 00:03:03 -07:00
|
|
|
detail::parse(tree, rootidx, breakpoints, final_type);
|
2015-06-27 17:55:01 -07:00
|
|
|
|
|
|
|
/*----Compute required temporaries----*/
|
2015-06-28 00:03:03 -07:00
|
|
|
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
|
2015-06-27 17:55:01 -07:00
|
|
|
{
|
2015-08-12 00:46:51 -07:00
|
|
|
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(it->first, dtype)];
|
2015-12-19 02:55:24 -05:00
|
|
|
expression_tree::node const & node = tree[it->second->node_index];
|
|
|
|
expression_tree::node const & lmost = lhs_most(tree, node);
|
2015-06-27 17:55:01 -07:00
|
|
|
|
|
|
|
//Creates temporary
|
2015-07-28 15:26:10 -07:00
|
|
|
std::shared_ptr<array> tmp;
|
2015-06-28 00:03:03 -07:00
|
|
|
switch(it->first){
|
2015-12-19 00:20:27 -05:00
|
|
|
case REDUCE_1D: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
|
2015-07-11 09:36:01 -04:00
|
|
|
|
2015-12-19 00:20:27 -05:00
|
|
|
case ELEMENTWISE_1D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
|
|
|
case REDUCE_2D_ROWS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
|
|
|
case REDUCE_2D_COLS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
2015-07-11 09:36:01 -04:00
|
|
|
|
2015-12-19 00:20:27 -05:00
|
|
|
case ELEMENTWISE_2D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
|
|
|
case MATRIX_PRODUCT_NN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
|
|
|
case MATRIX_PRODUCT_NT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
|
|
|
case MATRIX_PRODUCT_TN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
|
|
|
case MATRIX_PRODUCT_TT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
2015-06-27 17:55:01 -07:00
|
|
|
|
|
|
|
default: throw std::invalid_argument("Unrecognized operation");
|
|
|
|
}
|
|
|
|
temporaries_.push_back(tmp);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-19 02:04:39 -05:00
|
|
|
tree[rootidx].op.type = ASSIGN_TYPE;
|
2015-06-27 17:55:01 -07:00
|
|
|
fill(tree[rootidx].lhs, (array&)*tmp);
|
2015-06-28 00:03:03 -07:00
|
|
|
tree[rootidx].rhs = *it->second;
|
2015-12-19 02:04:39 -05:00
|
|
|
tree[rootidx].rhs.subtype = it->second->subtype;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-06-27 17:55:01 -07:00
|
|
|
//Execute
|
2015-09-30 15:31:41 -04:00
|
|
|
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
2015-06-27 17:55:01 -07:00
|
|
|
tree[rootidx] = root_save;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-19 02:55:24 -05:00
|
|
|
//Incorporates the temporary within, the expression_tree
|
2015-06-28 00:03:03 -07:00
|
|
|
fill(*it->second, (array&)*tmp);
|
2015-06-27 17:55:01 -07:00
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
/*-----Compute final expression-----*/
|
2015-09-30 15:31:41 -04:00
|
|
|
profiles[std::make_pair(final_type, dtype)]->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
|
|
|
}
|
|
|
|
|
|
|
|
/** @brief Executes the expression held by @p c using the profile map that isaac::profiles::get
 *         returns for the queue selected by c's execution options for the expression's context
 */
void execute(execution_handler const & c)
{
  auto & queue_profiles = isaac::profiles::get(c.execution_options().queue(c.x().context()));
  execute(c, queue_profiles);
}
|
|
|
|
|
|
|
|
}
|