Bugfix in cast and relational operators
This commit is contained in:
@@ -8,20 +8,7 @@ namespace atidlas
|
||||
namespace detail
|
||||
{
|
||||
|
||||
bool is_node_leaf(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_TRANS_TYPE
|
||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
||||
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
|| op.type==OPERATOR_VDIAG_TYPE
|
||||
|| op.type==OPERATOR_REPEAT_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type==OPERATOR_OUTER_PROD_TYPE;
|
||||
}
|
||||
|
||||
|
||||
bool is_scalar_reduction(symbolic_expression_node const & node)
|
||||
{
|
||||
@@ -44,13 +31,50 @@ namespace detail
|
||||
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|
||||
|| op.type== OPERATOR_MULT_TYPE
|
||||
|| op.type== OPERATOR_DIV_TYPE;
|
||||
|| op.type== OPERATOR_DIV_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
|
||||
}
|
||||
|
||||
|
||||
bool is_cast(op_element const & op)
|
||||
{
|
||||
return op.type== OPERATOR_CAST_CHAR_TYPE
|
||||
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|
||||
|| op.type== OPERATOR_CAST_SHORT_TYPE
|
||||
|| op.type== OPERATOR_CAST_USHORT_TYPE
|
||||
|| op.type== OPERATOR_CAST_INT_TYPE
|
||||
|| op.type== OPERATOR_CAST_UINT_TYPE
|
||||
|| op.type== OPERATOR_CAST_LONG_TYPE
|
||||
|| op.type== OPERATOR_CAST_ULONG_TYPE
|
||||
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|
||||
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
|
||||
;
|
||||
}
|
||||
|
||||
bool is_node_leaf(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_TRANS_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
|| op.type==OPERATOR_VDIAG_TYPE
|
||||
|| op.type==OPERATOR_REPEAT_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
||||
;
|
||||
}
|
||||
|
||||
bool is_elementwise_function(op_element const & op)
|
||||
{
|
||||
return
|
||||
op.type == OPERATOR_CAST_CHAR_TYPE
|
||||
return op.type == OPERATOR_CAST_CHAR_TYPE
|
||||
|| op.type == OPERATOR_CAST_UCHAR_TYPE
|
||||
|| op.type == OPERATOR_CAST_SHORT_TYPE
|
||||
|| op.type == OPERATOR_CAST_USHORT_TYPE
|
||||
@@ -81,12 +105,6 @@ namespace detail
|
||||
|| op.type== OPERATOR_TANH_TYPE
|
||||
|
||||
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
|
||||
@@ -94,6 +112,8 @@ namespace detail
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
//
|
||||
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
|
||||
@@ -161,18 +181,6 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_TAN_TYPE : return "tan";
|
||||
case OPERATOR_TANH_TYPE : return "tanh";
|
||||
|
||||
case OPERATOR_CAST_CHAR_TYPE : return "(char)";
|
||||
case OPERATOR_CAST_UCHAR_TYPE : return "(uchar)";
|
||||
case OPERATOR_CAST_SHORT_TYPE : return "(short)";
|
||||
case OPERATOR_CAST_USHORT_TYPE : return "(ushort)";
|
||||
case OPERATOR_CAST_INT_TYPE : return "(int)";
|
||||
case OPERATOR_CAST_UINT_TYPE : return "(uint)";
|
||||
case OPERATOR_CAST_LONG_TYPE : return "(long)";
|
||||
case OPERATOR_CAST_ULONG_TYPE : return "(ulong)";
|
||||
case OPERATOR_CAST_HALF_TYPE : return "(half)";
|
||||
case OPERATOR_CAST_FLOAT_TYPE : return "(float)";
|
||||
case OPERATOR_CAST_DOUBLE_TYPE : return "(double)";
|
||||
|
||||
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
|
||||
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
|
||||
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
|
||||
@@ -193,12 +201,12 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_ACCESS_TYPE : return "[]";
|
||||
|
||||
//Relational
|
||||
case OPERATOR_ELEMENT_EQ_TYPE : return "isequal";
|
||||
case OPERATOR_ELEMENT_NEQ_TYPE : return "isnotequal";
|
||||
case OPERATOR_ELEMENT_GREATER_TYPE : return "isgreater";
|
||||
case OPERATOR_ELEMENT_GEQ_TYPE : return "isgreaterequal";
|
||||
case OPERATOR_ELEMENT_LESS_TYPE : return "isless";
|
||||
case OPERATOR_ELEMENT_LEQ_TYPE : return "islessequal";
|
||||
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
|
||||
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
|
||||
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
|
||||
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
|
||||
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
|
||||
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
|
||||
|
||||
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
|
||||
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
|
||||
@@ -261,7 +269,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
|
||||
void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const
|
||||
{
|
||||
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
|
||||
if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
|
||||
if(detail::is_cast(root_node.op))
|
||||
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
|
||||
else if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
|
||||
&& !detail::is_node_leaf(root_node.op))
|
||||
str_+=evaluate(root_node.op.type);
|
||||
str_+="(";
|
||||
|
Reference in New Issue
Block a user