2015-07-25 21:00:18 -07:00
|
|
|
#include <cstring>
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
#include "to_string.hpp"
|
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/array.h"
|
2015-08-04 20:56:05 -07:00
|
|
|
#include "isaac/kernels/parse.h"
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/exception/operation_not_supported.h"
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
namespace isaac
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
|
|
|
|
namespace detail
|
|
|
|
{
|
|
|
|
|
2015-01-29 01:00:50 -05:00
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Tells whether @p node is a dot-product node of the "vector" family.
 *  @return true iff the node's operator belongs to OPERATOR_VECTOR_DOT_TYPE_FAMILY
 *          (presumably a full reduction to a single scalar -- name-based, confirm). */
bool is_scalar_dot(math_expression::node const & node)
{
  bool const in_vector_dot_family = (node.op.type_family == OPERATOR_VECTOR_DOT_TYPE_FAMILY);
  return in_vector_dot_family;
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Tells whether @p node is a row-wise or column-wise dot-product node.
 *  @return true iff the node's operator family is OPERATOR_ROWS_DOT_TYPE_FAMILY
 *          or OPERATOR_COLUMNS_DOT_TYPE_FAMILY. */
bool is_vector_dot(math_expression::node const & node)
{
  switch (node.op.type_family)
  {
    case OPERATOR_ROWS_DOT_TYPE_FAMILY:
    case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
2015-06-28 17:53:16 -07:00
|
|
|
/** @brief Tells whether @p op writes into its left operand.
 *  @return true for plain assignment (=) and the in-place forms (+=, -=). */
bool is_assignment(op_element const & op)
{
  switch (op.type)
  {
    case OPERATOR_ASSIGN_TYPE:
    case OPERATOR_INPLACE_ADD_TYPE:
    case OPERATOR_INPLACE_SUB_TYPE:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
/** @brief Tells whether @p op is an infix element-wise operator.
 *
 *  The expression evaluator emits the symbol of such an operator between its two
 *  operands. The set is the assignment operators (see is_assignment), the
 *  arithmetic operators (+, -, *, /, in both plain and ELEMENT_ forms) and the
 *  element-wise comparisons (==, !=, >, <, >=, <=).
 */
bool is_elementwise_operator(op_element const & op)
{
  if (is_assignment(op))
    return true;

  switch (op.type)
  {
    // arithmetic
    case OPERATOR_ADD_TYPE:
    case OPERATOR_SUB_TYPE:
    case OPERATOR_ELEMENT_PROD_TYPE:
    case OPERATOR_ELEMENT_DIV_TYPE:
    case OPERATOR_MULT_TYPE:
    case OPERATOR_DIV_TYPE:
    // comparisons
    case OPERATOR_ELEMENT_EQ_TYPE:
    case OPERATOR_ELEMENT_NEQ_TYPE:
    case OPERATOR_ELEMENT_GREATER_TYPE:
    case OPERATOR_ELEMENT_LESS_TYPE:
    case OPERATOR_ELEMENT_GEQ_TYPE:
    case OPERATOR_ELEMENT_LEQ_TYPE:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
/** @brief Tells whether @p op is a reshape or a transpose.
 *  NOTE(review): the name suggests these operators are skipped ("bypassed") by
 *  callers; that use is not visible here -- confirm at the call sites.
 *  @return true iff op.type is OPERATOR_RESHAPE_TYPE or OPERATOR_TRANS_TYPE. */
bool bypass(op_element const & op)
{
  switch (op.type)
  {
    case OPERATOR_RESHAPE_TYPE:
    case OPERATOR_TRANS_TYPE:
      return true;
    default:
      return false;
  }
}
|
2015-01-29 01:00:50 -05:00
|
|
|
|
|
|
|
/** @brief Tells whether @p op is one of the scalar type-conversion operators
 *  (bool, [u]char, [u]short, [u]int, [u]long, float, double). */
bool is_cast(op_element const & op)
{
  switch (op.type)
  {
    case OPERATOR_CAST_BOOL_TYPE:
    case OPERATOR_CAST_CHAR_TYPE:
    case OPERATOR_CAST_UCHAR_TYPE:
    case OPERATOR_CAST_SHORT_TYPE:
    case OPERATOR_CAST_USHORT_TYPE:
    case OPERATOR_CAST_INT_TYPE:
    case OPERATOR_CAST_UINT_TYPE:
    case OPERATOR_CAST_LONG_TYPE:
    case OPERATOR_CAST_ULONG_TYPE:
    case OPERATOR_CAST_FLOAT_TYPE:
    case OPERATOR_CAST_DOUBLE_TYPE:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
|
|
|
/** @brief Tells whether a node carrying @p op is rendered as one mapped object
 *  instead of having its operands expanded.
 *
 *  The expression evaluator emits mapping_.at(key)->evaluate(...) for such
 *  nodes. Matching is done either on the exact operator type (diag, repeat,
 *  row/column extraction, index access, outer product) or on the operator
 *  family (all dot-product families and GEMM).
 */
bool is_node_leaf(op_element const & op)
{
  switch (op.type)
  {
    case OPERATOR_MATRIX_DIAG_TYPE:
    case OPERATOR_VDIAG_TYPE:
    case OPERATOR_REPEAT_TYPE:
    case OPERATOR_MATRIX_ROW_TYPE:
    case OPERATOR_MATRIX_COLUMN_TYPE:
    case OPERATOR_ACCESS_INDEX_TYPE:
    case OPERATOR_OUTER_PROD_TYPE:
      return true;
    default:
      break;
  }

  switch (op.type_family)
  {
    case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
    case OPERATOR_ROWS_DOT_TYPE_FAMILY:
    case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
    case OPERATOR_GEMM_TYPE_FAMILY:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
|
|
|
/** @brief Tells whether @p op is spelled as a function call, name(args...).
 *
 *  The set is the casts (see is_cast), the one-argument math functions and the
 *  two-argument element-wise functions pow/fmax/fmin/max/min. The expression
 *  evaluator emits the function name and "(" before the operands and "," between
 *  them for these operators.
 */
bool is_elementwise_function(op_element const & op)
{
  if (is_cast(op))
    return true;

  switch (op.type)
  {
    // one-argument math functions
    case OPERATOR_ABS_TYPE:
    case OPERATOR_ACOS_TYPE:
    case OPERATOR_ASIN_TYPE:
    case OPERATOR_ATAN_TYPE:
    case OPERATOR_CEIL_TYPE:
    case OPERATOR_COS_TYPE:
    case OPERATOR_COSH_TYPE:
    case OPERATOR_EXP_TYPE:
    case OPERATOR_FABS_TYPE:
    case OPERATOR_FLOOR_TYPE:
    case OPERATOR_LOG_TYPE:
    case OPERATOR_LOG10_TYPE:
    case OPERATOR_SIN_TYPE:
    case OPERATOR_SINH_TYPE:
    case OPERATOR_SQRT_TYPE:
    case OPERATOR_TAN_TYPE:
    case OPERATOR_TANH_TYPE:
    // two-argument element-wise functions
    case OPERATOR_ELEMENT_POW_TYPE:
    case OPERATOR_ELEMENT_FMAX_TYPE:
    case OPERATOR_ELEMENT_FMIN_TYPE:
    case OPERATOR_ELEMENT_MAX_TYPE:
    case OPERATOR_ELEMENT_MIN_TYPE:
      return true;
    default:
      return false;
  }
}
|
|
|
|
|
2015-01-29 01:00:50 -05:00
|
|
|
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
//
|
|
|
|
/** @brief Builds a traversal callback that appends to @p out the index of every
 *  visited parent node for which @p pred returns true.
 *  The callback writes into @p out after construction, so @p out must stay alive
 *  for as long as the functor is used. */
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out)
  : pred_(pred),
    out_(out)
{
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Traversal callback: records @p root_idx when the visit is for the
 *  parent node itself (PARENT_NODE_TYPE) and the predicate accepts that node.
 *  LHS/RHS visits are ignored. */
void filter_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf) const
{
  if (leaf != PARENT_NODE_TYPE)
    return;

  math_expression::node const & candidate = math_expression.tree()[root_idx];
  if (pred_(candidate))
    out_.push_back(root_idx);
}
|
|
|
|
|
|
|
|
//
|
2015-09-30 15:31:41 -04:00
|
|
|
std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node), isaac::math_expression const & math_expression, size_t root, bool inspect)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
std::vector<size_t> res;
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root, filter_fun(pred, res), inspect);
|
2015-01-12 13:20:53 -05:00
|
|
|
return res;
|
|
|
|
}
|
|
|
|
|
|
|
|
//
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Builds a traversal callback that copies into @p out every operand
 *  (lhs or rhs) whose subtype equals @p subtype.
 *  The callback writes into @p out after construction, so @p out must outlive the functor. */
filter_elements_fun::filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out)
  : subtype_(subtype),
    out_(out)
{
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Traversal callback: appends the lhs and/or rhs operand of node
 *  @p root_idx when its subtype matches the requested one (lhs first).
 *  Which side triggered the visit (the unnamed leaf_t argument) is ignored.
 *  NOTE(review): if traverse() calls the functor more than once per node
 *  (e.g. once per leaf kind), matching operands would be recorded repeatedly --
 *  confirm against traverse(). */
void filter_elements_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const
{
  math_expression::node const & visited = math_expression.tree()[root_idx];

  if (visited.lhs.subtype == subtype_)
    out_.push_back(visited.lhs);

  if (visited.rhs.subtype == subtype_)
    out_.push_back(visited.rhs);
}
|
|
|
|
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Returns every operand of @p math_expression whose subtype is @p subtype.
 *  The walk starts at math_expression.root() and passes true as traverse()'s last argument.
 *  @return matching operands, in the order they are encountered. */
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype, isaac::math_expression const & math_expression)
{
  std::vector<lhs_rhs_element> found;
  filter_elements_fun collector(subtype, found);
  traverse(math_expression, math_expression.root(), collector, true);
  return found;
}
|
|
|
|
|
|
|
|
/** @brief generate a string from an operation_node_type */
|
|
|
|
const char * evaluate(operation_node_type type)
|
|
|
|
{
|
|
|
|
// unary expression
|
|
|
|
switch (type)
|
|
|
|
{
|
|
|
|
//Function
|
|
|
|
case OPERATOR_ABS_TYPE : return "abs";
|
|
|
|
case OPERATOR_ACOS_TYPE : return "acos";
|
|
|
|
case OPERATOR_ASIN_TYPE : return "asin";
|
|
|
|
case OPERATOR_ATAN_TYPE : return "atan";
|
|
|
|
case OPERATOR_CEIL_TYPE : return "ceil";
|
|
|
|
case OPERATOR_COS_TYPE : return "cos";
|
|
|
|
case OPERATOR_COSH_TYPE : return "cosh";
|
|
|
|
case OPERATOR_EXP_TYPE : return "exp";
|
|
|
|
case OPERATOR_FABS_TYPE : return "fabs";
|
|
|
|
case OPERATOR_FLOOR_TYPE : return "floor";
|
|
|
|
case OPERATOR_LOG_TYPE : return "log";
|
|
|
|
case OPERATOR_LOG10_TYPE : return "log10";
|
|
|
|
case OPERATOR_SIN_TYPE : return "sin";
|
|
|
|
case OPERATOR_SINH_TYPE : return "sinh";
|
|
|
|
case OPERATOR_SQRT_TYPE : return "sqrt";
|
|
|
|
case OPERATOR_TAN_TYPE : return "tan";
|
|
|
|
case OPERATOR_TANH_TYPE : return "tanh";
|
|
|
|
|
|
|
|
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
|
|
|
|
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
|
|
|
|
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
|
|
|
|
case OPERATOR_ELEMENT_ARGMIN_TYPE : return "argmin";
|
|
|
|
case OPERATOR_ELEMENT_POW_TYPE : return "pow";
|
|
|
|
|
2015-11-19 12:37:18 -05:00
|
|
|
//Arithmetic
|
2015-01-12 13:20:53 -05:00
|
|
|
case OPERATOR_MINUS_TYPE : return "-";
|
|
|
|
case OPERATOR_ASSIGN_TYPE : return "=";
|
|
|
|
case OPERATOR_INPLACE_ADD_TYPE : return "+=";
|
|
|
|
case OPERATOR_INPLACE_SUB_TYPE : return "-=";
|
|
|
|
case OPERATOR_ADD_TYPE : return "+";
|
|
|
|
case OPERATOR_SUB_TYPE : return "-";
|
|
|
|
case OPERATOR_MULT_TYPE : return "*";
|
|
|
|
case OPERATOR_ELEMENT_PROD_TYPE : return "*";
|
|
|
|
case OPERATOR_DIV_TYPE : return "/";
|
|
|
|
case OPERATOR_ELEMENT_DIV_TYPE : return "/";
|
|
|
|
|
2015-11-19 12:37:18 -05:00
|
|
|
//Relational
|
2015-01-29 15:19:40 -05:00
|
|
|
case OPERATOR_NEGATE_TYPE: return "!";
|
2015-01-29 01:00:50 -05:00
|
|
|
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
|
|
|
|
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
|
|
|
|
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
|
|
|
|
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
|
|
|
|
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
|
|
|
|
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
|
|
|
|
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
|
|
|
|
case OPERATOR_ELEMENT_MAX_TYPE : return "max";
|
|
|
|
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
|
|
|
|
2015-11-19 12:37:18 -05:00
|
|
|
//Binary
|
2015-07-11 09:36:01 -04:00
|
|
|
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
|
|
|
|
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
|
|
|
|
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
|
|
|
|
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
|
2015-01-17 10:48:02 -05:00
|
|
|
case OPERATOR_VDIAG_TYPE : return "vdiag";
|
2015-01-12 13:20:53 -05:00
|
|
|
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
|
|
|
|
case OPERATOR_MATRIX_ROW_TYPE : return "row";
|
|
|
|
case OPERATOR_MATRIX_COLUMN_TYPE : return "col";
|
|
|
|
case OPERATOR_PAIR_TYPE: return "pair";
|
2015-09-30 15:31:41 -04:00
|
|
|
case OPERATOR_ACCESS_INDEX_TYPE: return "access";
|
|
|
|
|
2015-11-19 12:37:18 -05:00
|
|
|
//FOR
|
2015-09-30 15:31:41 -04:00
|
|
|
case OPERATOR_SFOR_TYPE: return "sfor";
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
default : throw operation_not_supported_exception("Unsupported operator");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/** @brief Builds a traversal callback that appends the text of visited nodes to @p str.
 *  @param accessors passed to each mapped object's evaluate()
 *  @param str       output text, appended to (never cleared) by the callbacks
 *  @param mapping   (node index, leaf kind) -> mapped object lookup */
evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::string, std::string> const & accessors, std::string & str, mapping_type const & mapping)
  : accessors_(accessors),
    str_(str),
    mapping_(mapping)
{
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Emits the text that precedes the operands of node @p root_idx.
 *
 *  - Cast nodes: the mapped object registered for the parent node is evaluated.
 *  - Otherwise, for nodes that are not node leaves and are either unary (except
 *    OPERATOR_ADD_TYPE) or element-wise functions: the name from evaluate(type).
 *  - Finally, every node except OPERATOR_FUSE opens a "(".
 */
void evaluate_expression_traversal::call_before_expansion(isaac::math_expression const & math_expression, std::size_t root_idx) const
{
  math_expression::node const & current = math_expression.tree()[root_idx];

  if (detail::is_cast(current.op))
  {
    str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
  }
  else
  {
    bool const prefix_unary = current.op.type_family == OPERATOR_UNARY_TYPE_FAMILY
                              && current.op.type != OPERATOR_ADD_TYPE;
    bool const spelled_as_call = prefix_unary || detail::is_elementwise_function(current.op);
    if (spelled_as_call && !detail::is_node_leaf(current.op))
      str_ += evaluate(current.op.type);
  }

  if (current.op.type != OPERATOR_FUSE)
    str_ += "(";
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Closes the parenthesis opened by call_before_expansion: every node
 *  except OPERATOR_FUSE gets a trailing ")". */
void evaluate_expression_traversal::call_after_expansion(math_expression const & math_expression, std::size_t root_idx) const
{
  bool const parenthesized = math_expression.tree()[root_idx].op.type != OPERATOR_FUSE;
  if (parenthesized)
    str_ += ")";
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Traversal callback: appends the text for one visit of node @p root_idx.
 *
 *  - PARENT visit: node leaves emit their mapped object. Non-unary nodes emit
 *    their infix symbol (element-wise operators) or "," (element-wise functions).
 *  - LHS / RHS visit: operands that are not operator nodes
 *    (COMPOSITE_OPERATOR_FAMILY) emit either "sforidx<level>" for a loop index,
 *    or their mapped object. Operator operands emit nothing here.
 */
void evaluate_expression_traversal::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const
{
  math_expression::node const & current = math_expression.tree()[root_idx];
  mapping_type::key_type key = std::make_pair(root_idx, leaf);

  switch (leaf)
  {
    case PARENT_NODE_TYPE:
      if (detail::is_node_leaf(current.op))
      {
        str_ += mapping_.at(key)->evaluate(accessors_);
      }
      else if (current.op.type_family != OPERATOR_UNARY_TYPE_FAMILY)
      {
        if (detail::is_elementwise_operator(current.op))
          str_ += evaluate(current.op.type);
        else if (detail::is_elementwise_function(current.op))
          str_ += ",";
      }
      break;

    case LHS_NODE_TYPE:
      if (current.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
      {
        if (current.lhs.subtype == FOR_LOOP_INDEX_TYPE)
          str_ += "sforidx" + tools::to_string(current.lhs.for_idx.level);
        else
          str_ += mapping_.at(key)->evaluate(accessors_);
      }
      break;

    case RHS_NODE_TYPE:
      if (current.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
      {
        if (current.rhs.subtype == FOR_LOOP_INDEX_TYPE)
          str_ += "sforidx" + tools::to_string(current.rhs.for_idx.level);
        else
          str_ += mapping_.at(key)->evaluate(accessors_);
      }
      break;

    default:
      break;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
2015-09-30 15:31:41 -04:00
|
|
|
isaac::math_expression const & math_expression, std::size_t root_idx, mapping_type const & mapping)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
std::string res;
|
|
|
|
evaluate_expression_traversal traversal_functor(accessors, res, mapping);
|
2015-09-30 15:31:41 -04:00
|
|
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
if (leaf==RHS_NODE_TYPE)
|
|
|
|
{
|
|
|
|
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-09-30 15:31:41 -04:00
|
|
|
traversal_functor(math_expression, root_idx, leaf);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
else if (leaf==LHS_NODE_TYPE)
|
|
|
|
{
|
|
|
|
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-09-30 15:31:41 -04:00
|
|
|
traversal_functor(math_expression, root_idx, leaf);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
else
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_idx, traversal_functor, false);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
|
|
|
|
void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
2015-09-30 15:31:41 -04:00
|
|
|
math_expression const & x, mapping_type const & mapping)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-09-30 15:31:41 -04:00
|
|
|
stream << evaluate(leaf, accessors, x, x.root(), mapping) << std::endl;
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
/** @brief Builds a traversal callback that writes mapped-object "process" code to @p stream.
 *  @param already_processed names of objects already handled. It is updated by
 *         the callbacks, so it can be shared across several traversals. */
process_traversal::process_traversal(std::map<std::string, std::string> const & accessors, kernel_generation_stream & stream,
                                     mapping_type const & mapping, std::set<std::string> & already_processed)
  : accessors_(accessors),
    stream_(stream),
    mapping_(mapping),
    already_processed_(already_processed)
{
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Traversal callback: emits the processing code of the object mapped at
 *  (@p root_idx, @p leaf), if any.
 *
 *  An accessor registered under the object's name is used first. Otherwise an
 *  accessor registered under its type key is used. Both branches record the
 *  object's *name* in already_processed_. As a result, each object is processed
 *  at most once per shared set, and a name-specific accessor takes precedence
 *  over the type-key one. accessors_ is a std::map, so each lookup yields at
 *  most one accessor.
 */
void process_traversal::operator()(math_expression const & /*math_expression*/, std::size_t root_idx, leaf_t leaf) const
{
  mapping_type::const_iterator mapped = mapping_.find(std::make_pair(root_idx, leaf));
  if (mapped == mapping_.end())
    return;

  mapped_object * obj = mapped->second.get();
  std::string name = obj->name();

  std::map<std::string, std::string>::const_iterator by_name = accessors_.find(name);
  if (by_name != accessors_.end() && already_processed_.insert(name).second)
    stream_ << obj->process(by_name->second) << std::endl;

  std::string key = obj->type_key();
  std::map<std::string, std::string>::const_iterator by_type = accessors_.find(key);
  // The name (not the type key) is recorded on purpose: objects of the same type
  // must each be processed, but only once.
  if (by_type != accessors_.end() && already_processed_.insert(name).second)
    stream_ << obj->process(by_type->second) << std::endl;
}
|
|
|
|
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
2015-09-30 15:31:41 -04:00
|
|
|
isaac::math_expression const & math_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
process_traversal traversal_functor(accessors, stream, mapping, already_processed);
|
2015-09-30 15:31:41 -04:00
|
|
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
if (leaf==RHS_NODE_TYPE)
|
|
|
|
{
|
|
|
|
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-09-30 15:31:41 -04:00
|
|
|
traversal_functor(math_expression, root_idx, leaf);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
else if (leaf==LHS_NODE_TYPE)
|
|
|
|
{
|
|
|
|
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-09-30 15:31:41 -04:00
|
|
|
traversal_functor(math_expression, root_idx, leaf);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
2015-09-30 15:31:41 -04:00
|
|
|
traverse(math_expression, root_idx, traversal_functor, true);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
2015-09-30 15:31:41 -04:00
|
|
|
math_expression const & x, mapping_type const & mapping)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-09-30 15:31:41 -04:00
|
|
|
std::set<std::string> processed;
|
|
|
|
process(stream, leaf, accessors, x, x.root(), mapping, processed);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Writes the decimal digits of @p val at @p ptr and advances @p ptr past them.
 *
 *  Digits are written least-significant first (120 -> "021"). Zero is written
 *  as "0". No terminator or separator is written.
 *  NOTE(review): values appended back-to-back are not self-delimiting
 *  (1 then 32 gives "123", and so does 21 then 3). Confirm that the representation
 *  built from these ids does not need to be collision-free.
 */
void math_expression_representation_functor::append_id(char * & ptr, unsigned int val)
{
  do
  {
    *ptr++ = static_cast<char>('0' + val % 10);
    val /= 10;
  } while (val != 0);
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
//void math_expression_representation_functor::append(driver::Buffer const & h, numeric_type dtype, char prefix, bool is_assigned) const
|
|
|
|
//{
|
|
|
|
// *ptr_++=prefix;
|
|
|
|
// *ptr_++=(char)dtype;
|
|
|
|
// append_id(ptr_, binder_.get(h, is_assigned));
|
|
|
|
//}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Appends the signature of a dense-array operand to the representation.
 *
 *  Three pieces are written: a shape-class character, the dtype as one character,
 *  and the binder id (see append_id). The shape class is:
 *    '0' if shape().max() == 1,
 *    '1' else if the array is 1-D or shape().min() == 1,
 *    '2' otherwise.
 *  Operands of any other subtype contribute nothing.
 *  @param is_assigned forwarded to the binder (true for the written operand).
 */
void math_expression_representation_functor::append(lhs_rhs_element const & lhs_rhs, bool is_assigned) const
{
  if (lhs_rhs.subtype != DENSE_ARRAY_TYPE)
    return;

  char const prefix = (lhs_rhs.array->shape().max() == 1) ? '0'
                    : (lhs_rhs.array->dim() == 1 || lhs_rhs.array->shape().min() == 1) ? '1'
                    : '2';
  numeric_type dtype = lhs_rhs.array->dtype();

  *ptr_++ = prefix;
  *ptr_++ = static_cast<char>(dtype);
  append_id(ptr_, binder_.get(lhs_rhs.array, is_assigned));
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Builds a functor that writes an expression's representation through @p ptr,
 *  using @p binder to assign ids to arrays. @p ptr is advanced as text is written. */
math_expression_representation_functor::math_expression_representation_functor(symbolic_binder & binder, char *& ptr)
  : binder_(binder),
    ptr_(ptr)
{
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
void math_expression_representation_functor::append(char*& p, const char * str) const
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
std::size_t n = std::strlen(str);
|
|
|
|
std::memcpy(p, str, n);
|
|
|
|
p+=n;
|
|
|
|
}
|
|
|
|
|
2015-09-30 15:31:41 -04:00
|
|
|
/** @brief Traversal callback: appends one piece of the representation for node @p root_idx.
 *
 *  - LHS visit of a non-operator operand: its signature. The operand is flagged
 *    as assigned when the node is an assignment.
 *  - RHS visit of a non-operator operand: its signature, never flagged as assigned.
 *  - PARENT visit: the numeric operator type, via append_id.
 *  The parameter was previously named `leaf_t`, which shadowed the type of the
 *  same name; it is renamed `leaf` here.
 */
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const
{
  math_expression::node const & current = math_expression.tree()[root_idx];

  switch (leaf)
  {
    case LHS_NODE_TYPE:
      if (current.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
        append(current.lhs, detail::is_assignment(current.op));
      break;

    case RHS_NODE_TYPE:
      if (current.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
        append(current.rhs, false);
      break;

    case PARENT_NODE_TYPE:
      append_id(ptr_, current.op.type);
      break;

    default:
      break;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|