Bugfix in cast and relational operators

This commit is contained in:
Philippe Tillet
2015-01-29 01:00:50 -05:00
parent c7665021d1
commit d4629ba018
13 changed files with 198 additions and 125 deletions

View File

@@ -8,20 +8,7 @@ namespace atidlas
namespace detail
{
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type==OPERATOR_OUTER_PROD_TYPE;
}
bool is_scalar_reduction(symbolic_expression_node const & node)
{
@@ -44,13 +31,50 @@ namespace detail
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|| op.type== OPERATOR_MULT_TYPE
|| op.type== OPERATOR_DIV_TYPE;
|| op.type== OPERATOR_DIV_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
}
bool is_cast(op_element const & op)
{
return op.type== OPERATOR_CAST_CHAR_TYPE
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|| op.type== OPERATOR_CAST_SHORT_TYPE
|| op.type== OPERATOR_CAST_USHORT_TYPE
|| op.type== OPERATOR_CAST_INT_TYPE
|| op.type== OPERATOR_CAST_UINT_TYPE
|| op.type== OPERATOR_CAST_LONG_TYPE
|| op.type== OPERATOR_CAST_ULONG_TYPE
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
;
}
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
;
}
bool is_elementwise_function(op_element const & op)
{
return
op.type == OPERATOR_CAST_CHAR_TYPE
return op.type == OPERATOR_CAST_CHAR_TYPE
|| op.type == OPERATOR_CAST_UCHAR_TYPE
|| op.type == OPERATOR_CAST_SHORT_TYPE
|| op.type == OPERATOR_CAST_USHORT_TYPE
@@ -81,12 +105,6 @@ namespace detail
|| op.type== OPERATOR_TANH_TYPE
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
@@ -94,6 +112,8 @@ namespace detail
}
}
//
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
@@ -161,18 +181,6 @@ const char * evaluate(operation_node_type type)
case OPERATOR_TAN_TYPE : return "tan";
case OPERATOR_TANH_TYPE : return "tanh";
case OPERATOR_CAST_CHAR_TYPE : return "(char)";
case OPERATOR_CAST_UCHAR_TYPE : return "(uchar)";
case OPERATOR_CAST_SHORT_TYPE : return "(short)";
case OPERATOR_CAST_USHORT_TYPE : return "(ushort)";
case OPERATOR_CAST_INT_TYPE : return "(int)";
case OPERATOR_CAST_UINT_TYPE : return "(uint)";
case OPERATOR_CAST_LONG_TYPE : return "(long)";
case OPERATOR_CAST_ULONG_TYPE : return "(ulong)";
case OPERATOR_CAST_HALF_TYPE : return "(half)";
case OPERATOR_CAST_FLOAT_TYPE : return "(float)";
case OPERATOR_CAST_DOUBLE_TYPE : return "(double)";
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
@@ -193,12 +201,12 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ACCESS_TYPE : return "[]";
//Relational
case OPERATOR_ELEMENT_EQ_TYPE : return "isequal";
case OPERATOR_ELEMENT_NEQ_TYPE : return "isnotequal";
case OPERATOR_ELEMENT_GREATER_TYPE : return "isgreater";
case OPERATOR_ELEMENT_GEQ_TYPE : return "isgreaterequal";
case OPERATOR_ELEMENT_LESS_TYPE : return "isless";
case OPERATOR_ELEMENT_LEQ_TYPE : return "islessequal";
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
@@ -261,7 +269,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const
{
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
&& !detail::is_node_leaf(root_node.op))
str_+=evaluate(root_node.op.type);
str_+="(";