Bugfix in cast and relational operators
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
#define ATIDLAS_ARRAY_H_
|
||||
|
||||
#include <iostream>
|
||||
#include "atidlas/types.h"
|
||||
#include <CL/cl.hpp>
|
||||
#include "atidlas/types.h"
|
||||
#include "atidlas/cl/queues.h"
|
||||
#include "atidlas/symbolic/expression.h"
|
||||
|
||||
@@ -118,7 +118,6 @@ public:
|
||||
|
||||
|
||||
//copy
|
||||
|
||||
void copy(void const * data, array & gx, cl::CommandQueue & queue, bool blocking = true);
|
||||
void copy(array const & gx, void* data, cl::CommandQueue & queue, bool blocking = true);
|
||||
void copy(void const *data, array &gx, bool blocking = true);
|
||||
@@ -192,6 +191,9 @@ ATIDLAS_DECLARE_UNARY_OPERATOR(tan)
|
||||
ATIDLAS_DECLARE_UNARY_OPERATOR(tanh)
|
||||
ATIDLAS_DECLARE_UNARY_OPERATOR(trans)
|
||||
|
||||
array_expression cast(array const &, numeric_type dtype);
|
||||
array_expression cast(array_expression const &, numeric_type dtype);
|
||||
|
||||
array_expression norm(array const &, unsigned int order = 2);
|
||||
array_expression norm(array_expression const &, unsigned int order = 2);
|
||||
|
||||
|
@@ -235,6 +235,12 @@ public:
|
||||
mapped_outer(std::string const & scalartype, unsigned int id, node_info info);
|
||||
};
|
||||
|
||||
class mapped_cast : public mapped_object
|
||||
{
|
||||
static std::string operator_to_str(operation_node_type type);
|
||||
public:
|
||||
mapped_cast(operation_node_type type, unsigned int id);
|
||||
};
|
||||
|
||||
}
|
||||
#endif
|
||||
|
@@ -17,6 +17,8 @@ namespace detail
|
||||
bool is_vector_reduction(symbolic_expression_node const & node);
|
||||
bool is_elementwise_operator(op_element const & op);
|
||||
bool is_elementwise_function(op_element const & op);
|
||||
bool is_cast(op_element const & op);
|
||||
|
||||
}
|
||||
|
||||
class scalar;
|
||||
|
@@ -193,9 +193,9 @@ public:
|
||||
typedef std::vector<value_type> container_type;
|
||||
|
||||
symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & context, numeric_type const & dtype);
|
||||
symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op);
|
||||
symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op);
|
||||
symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op);
|
||||
symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype);
|
||||
symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
|
||||
symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
|
||||
|
||||
container_type & tree();
|
||||
container_type const & tree() const;
|
||||
@@ -212,9 +212,9 @@ protected:
|
||||
struct array_expression: public symbolic_expression
|
||||
{
|
||||
array_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & ctx, numeric_type const & dtype, size4 shape);
|
||||
array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape);
|
||||
array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape);
|
||||
array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape);
|
||||
array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
|
||||
array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
|
||||
array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
|
||||
size4 shape() const;
|
||||
array_expression& reshape(int_t size1, int_t size2=1);
|
||||
int_t nshape() const;
|
||||
|
116
lib/array.cpp
116
lib/array.cpp
@@ -132,6 +132,7 @@ int_t array::dsize() const
|
||||
//---------------------------------------
|
||||
array & array::operator=(array const & rhs)
|
||||
{
|
||||
assert(dtype_ == rhs.dtype());
|
||||
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
|
||||
model_map_t & mmap = atidlas::get_model_map(queue);
|
||||
@@ -141,7 +142,8 @@ array & array::operator=(array const & rhs)
|
||||
|
||||
array & array::operator=(array_expression const & rhs)
|
||||
{
|
||||
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), shape_);
|
||||
assert(dtype_ == rhs.dtype());
|
||||
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), dtype_, shape_);
|
||||
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
|
||||
model_map_t & mmap = atidlas::get_model_map(queue);
|
||||
execute(expression, mmap);
|
||||
@@ -181,7 +183,7 @@ array & array::operator+=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator+=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), shape_); }
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator-=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
@@ -190,7 +192,7 @@ array & array::operator-=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator-=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator*=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
@@ -199,7 +201,7 @@ array & array::operator*=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator*=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), shape_); }
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator/=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
@@ -208,7 +210,7 @@ array & array::operator/=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator/=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), shape_); }
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), dtype_, shape_); }
|
||||
|
||||
array_expression array::T() const
|
||||
{ return atidlas::trans(*this) ;}
|
||||
@@ -385,51 +387,55 @@ bool check_elementwise(U const & u, V const & v)
|
||||
return max(u.shape())==1 || max(v.shape())==1 || u.shape()==v.shape();
|
||||
}
|
||||
|
||||
|
||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME) \
|
||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
||||
array_expression OPNAME (array_expression const & x, array_expression const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y) ); } \
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, array_expression const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), elementwise_size(x, y)); }\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); }\
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x, value_scalar const & y) \
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, value_scalar const & y) \
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
array_expression OPNAME (value_scalar const & y, array_expression const & x) \
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (value_scalar const & y, array const & x) \
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
|
||||
|
||||
namespace detail
|
||||
{ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype()) }
|
||||
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
|
||||
|
||||
array_expression outer(array const & x, array const & y)
|
||||
{
|
||||
@@ -437,10 +443,6 @@ array_expression outer(array const & x, array const & y)
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(max(x.shape()), max(y.shape())) );
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign)
|
||||
}
|
||||
|
||||
#undef DEFINE_ELEMENT_BINARY_OPERATOR
|
||||
//---------------------------------------
|
||||
@@ -452,7 +454,7 @@ array_expression OPNAME (array const & x) \
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x) \
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.shape()); }
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.dtype(), x.shape()); }
|
||||
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
|
||||
@@ -476,15 +478,35 @@ DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
|
||||
|
||||
///*--- Misc----*/
|
||||
////---------------------------------------
|
||||
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
|
||||
inline operation_node_type casted(numeric_type dtype)
|
||||
{
|
||||
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
|
||||
switch(dtype)
|
||||
{
|
||||
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
|
||||
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
|
||||
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
|
||||
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
|
||||
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
|
||||
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
|
||||
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
|
||||
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
|
||||
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
|
||||
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
|
||||
default: throw unknown_datatype(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
array_expression cast(array const & x, numeric_type dtype)
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
array_expression cast(array_expression const & x, numeric_type dtype)
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), dtype, x.shape()); }
|
||||
|
||||
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
|
||||
{ return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
|
||||
{
|
||||
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
|
||||
}
|
||||
{ return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
inline size4 flip(size4 const & shape)
|
||||
{ return size4(shape._2, shape._1);}
|
||||
@@ -496,7 +518,7 @@ array_expression trans(array const & x) \
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
\
|
||||
array_expression trans(array_expression const & x) \
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), flip(x.shape())); }
|
||||
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.dtype(), flip(x.shape())); }
|
||||
|
||||
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
@@ -515,7 +537,7 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
|
||||
infos.rep2 = rep2;
|
||||
infos.sub1 = A.shape()._1;
|
||||
infos.sub2 = A.shape()._2;
|
||||
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
|
||||
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
|
||||
}
|
||||
|
||||
////---------------------------------------
|
||||
@@ -540,11 +562,11 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
|
||||
if(axis < -1 || axis > x.nshape())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
if(axis==-1)\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), size4(1));\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(1));\
|
||||
else if(axis==0)\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._1));\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._1));\
|
||||
else\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._2));\
|
||||
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._2));\
|
||||
}
|
||||
|
||||
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
|
||||
@@ -576,7 +598,7 @@ namespace detail
|
||||
shape._1 = A.shape()._2;
|
||||
}
|
||||
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
|
||||
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
return res;
|
||||
@@ -593,7 +615,7 @@ namespace detail
|
||||
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
||||
shape._2 = B.shape()._1;
|
||||
}
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
|
||||
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
@@ -615,7 +637,7 @@ namespace detail
|
||||
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
||||
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
|
||||
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
|
||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
|
||||
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
@@ -639,7 +661,7 @@ namespace detail
|
||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||
if(A_trans)
|
||||
{
|
||||
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M));
|
||||
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.dtype(), size4(N, M));
|
||||
//Remove trans
|
||||
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
||||
return sum(tmp, 1);
|
||||
|
@@ -335,4 +335,27 @@ void mapped_outer::postprocess(std::string &res) const
|
||||
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
|
||||
{ }
|
||||
|
||||
std::string mapped_cast::operator_to_str(operation_node_type type)
|
||||
{
|
||||
switch(type)
|
||||
{
|
||||
case OPERATOR_CAST_CHAR_TYPE : return "char";
|
||||
case OPERATOR_CAST_UCHAR_TYPE : return "uchar";
|
||||
case OPERATOR_CAST_SHORT_TYPE : return "short";
|
||||
case OPERATOR_CAST_USHORT_TYPE : return "ushort";
|
||||
case OPERATOR_CAST_INT_TYPE : return "int";
|
||||
case OPERATOR_CAST_UINT_TYPE : return "uint";
|
||||
case OPERATOR_CAST_LONG_TYPE : return "long";
|
||||
case OPERATOR_CAST_ULONG_TYPE : return "ulong";
|
||||
case OPERATOR_CAST_HALF_TYPE : return "half";
|
||||
case OPERATOR_CAST_FLOAT_TYPE : return "float";
|
||||
case OPERATOR_CAST_DOUBLE_TYPE : return "double";
|
||||
default : return "invalid";
|
||||
}
|
||||
}
|
||||
|
||||
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
|
||||
{ }
|
||||
|
||||
|
||||
}
|
||||
|
@@ -8,20 +8,7 @@ namespace atidlas
|
||||
namespace detail
|
||||
{
|
||||
|
||||
bool is_node_leaf(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_TRANS_TYPE
|
||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
||||
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
|| op.type==OPERATOR_VDIAG_TYPE
|
||||
|| op.type==OPERATOR_REPEAT_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type==OPERATOR_OUTER_PROD_TYPE;
|
||||
}
|
||||
|
||||
|
||||
bool is_scalar_reduction(symbolic_expression_node const & node)
|
||||
{
|
||||
@@ -44,13 +31,50 @@ namespace detail
|
||||
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|
||||
|| op.type== OPERATOR_MULT_TYPE
|
||||
|| op.type== OPERATOR_DIV_TYPE;
|
||||
|| op.type== OPERATOR_DIV_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
|
||||
}
|
||||
|
||||
|
||||
bool is_cast(op_element const & op)
|
||||
{
|
||||
return op.type== OPERATOR_CAST_CHAR_TYPE
|
||||
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|
||||
|| op.type== OPERATOR_CAST_SHORT_TYPE
|
||||
|| op.type== OPERATOR_CAST_USHORT_TYPE
|
||||
|| op.type== OPERATOR_CAST_INT_TYPE
|
||||
|| op.type== OPERATOR_CAST_UINT_TYPE
|
||||
|| op.type== OPERATOR_CAST_LONG_TYPE
|
||||
|| op.type== OPERATOR_CAST_ULONG_TYPE
|
||||
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|
||||
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
|
||||
;
|
||||
}
|
||||
|
||||
bool is_node_leaf(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_TRANS_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
|| op.type==OPERATOR_VDIAG_TYPE
|
||||
|| op.type==OPERATOR_REPEAT_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
||||
;
|
||||
}
|
||||
|
||||
bool is_elementwise_function(op_element const & op)
|
||||
{
|
||||
return
|
||||
op.type == OPERATOR_CAST_CHAR_TYPE
|
||||
return op.type == OPERATOR_CAST_CHAR_TYPE
|
||||
|| op.type == OPERATOR_CAST_UCHAR_TYPE
|
||||
|| op.type == OPERATOR_CAST_SHORT_TYPE
|
||||
|| op.type == OPERATOR_CAST_USHORT_TYPE
|
||||
@@ -81,12 +105,6 @@ namespace detail
|
||||
|| op.type== OPERATOR_TANH_TYPE
|
||||
|
||||
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|
||||
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
|
||||
@@ -94,6 +112,8 @@ namespace detail
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
//
|
||||
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
|
||||
@@ -161,18 +181,6 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_TAN_TYPE : return "tan";
|
||||
case OPERATOR_TANH_TYPE : return "tanh";
|
||||
|
||||
case OPERATOR_CAST_CHAR_TYPE : return "(char)";
|
||||
case OPERATOR_CAST_UCHAR_TYPE : return "(uchar)";
|
||||
case OPERATOR_CAST_SHORT_TYPE : return "(short)";
|
||||
case OPERATOR_CAST_USHORT_TYPE : return "(ushort)";
|
||||
case OPERATOR_CAST_INT_TYPE : return "(int)";
|
||||
case OPERATOR_CAST_UINT_TYPE : return "(uint)";
|
||||
case OPERATOR_CAST_LONG_TYPE : return "(long)";
|
||||
case OPERATOR_CAST_ULONG_TYPE : return "(ulong)";
|
||||
case OPERATOR_CAST_HALF_TYPE : return "(half)";
|
||||
case OPERATOR_CAST_FLOAT_TYPE : return "(float)";
|
||||
case OPERATOR_CAST_DOUBLE_TYPE : return "(double)";
|
||||
|
||||
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
|
||||
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
|
||||
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
|
||||
@@ -193,12 +201,12 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_ACCESS_TYPE : return "[]";
|
||||
|
||||
//Relational
|
||||
case OPERATOR_ELEMENT_EQ_TYPE : return "isequal";
|
||||
case OPERATOR_ELEMENT_NEQ_TYPE : return "isnotequal";
|
||||
case OPERATOR_ELEMENT_GREATER_TYPE : return "isgreater";
|
||||
case OPERATOR_ELEMENT_GEQ_TYPE : return "isgreaterequal";
|
||||
case OPERATOR_ELEMENT_LESS_TYPE : return "isless";
|
||||
case OPERATOR_ELEMENT_LEQ_TYPE : return "islessequal";
|
||||
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
|
||||
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
|
||||
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
|
||||
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
|
||||
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
|
||||
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
|
||||
|
||||
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
|
||||
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
|
||||
@@ -261,7 +269,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
|
||||
void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const
|
||||
{
|
||||
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
|
||||
if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
|
||||
if(detail::is_cast(root_node.op))
|
||||
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
|
||||
else if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
|
||||
&& !detail::is_node_leaf(root_node.op))
|
||||
str_+=evaluate(root_node.op.type);
|
||||
str_+="(";
|
||||
|
@@ -110,6 +110,8 @@ void base::map_functor::operator()(atidlas::symbolic_expression const & symbolic
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_cast(root_node.op))
|
||||
mapping_.insert(mapping_type::value_type(key, tools::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get(NULL)))));
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -28,7 +28,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
|
||||
kernel_generation_stream stream;
|
||||
|
||||
std::string init0, upper_bound0, inc0, init1, upper_bound1, inc1;
|
||||
|
||||
std::string data_type = append_width("#scalartype",simd_width);
|
||||
char kprefix[10];
|
||||
fill_kernel_name(kprefix, label, "d");
|
||||
|
||||
@@ -51,7 +51,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
|
||||
stream.inc_tab();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >
|
||||
("array2", append_width("#scalartype",simd_width) + " #namereg = $VALUE{i*#stride1,j*#stride2};")
|
||||
("array2", data_type + " #namereg = $VALUE{i*#stride1,j*#stride2};")
|
||||
("vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};")
|
||||
("repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};")
|
||||
("outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});")
|
||||
@@ -63,6 +63,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
|
||||
("repeat", "#namereg")
|
||||
("array0", "#namereg")
|
||||
("outer", "#namereg")
|
||||
("cast", "convert_"+data_type)
|
||||
, symbolic_expressions, mappings);
|
||||
|
||||
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;")
|
||||
|
@@ -61,7 +61,9 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
|
||||
("matrix_row", "#namereg")
|
||||
("matrix_column", "#namereg")
|
||||
("matrix_diag", "#namereg")
|
||||
("array0", "#namereg"), symbolic_expressions, mappings);
|
||||
("array0", "#namereg")
|
||||
("cast", "convert_"+data_type)
|
||||
, symbolic_expressions, mappings);
|
||||
|
||||
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array1", "#pointer[i*#stride] = #namereg;")
|
||||
("matrix_row", "$VALUE{#row, i} = #namereg;")
|
||||
@@ -82,6 +84,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
// std::cout << stream.str() << std::endl;
|
||||
result.push_back(stream.str());
|
||||
}
|
||||
|
||||
|
@@ -22,7 +22,7 @@ void fill(array const & a, array_infos& i)
|
||||
}
|
||||
|
||||
array_expression array_expression::operator-()
|
||||
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
|
||||
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
|
||||
|
||||
|
||||
lhs_rhs_element::lhs_rhs_element()
|
||||
@@ -80,8 +80,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_el
|
||||
tree_(1, symbolic_expression_node(lhs, op, rhs)), root_(0), context_(context), dtype_(dtype)
|
||||
{ }
|
||||
|
||||
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op) :
|
||||
context_(lhs.context_), dtype_(lhs.dtype_)
|
||||
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype) :
|
||||
context_(lhs.context_), dtype_(dtype)
|
||||
{
|
||||
tree_.reserve(lhs.tree_.size() + 1);
|
||||
tree_.insert(tree_.end(), lhs.tree_.begin(), lhs.tree_.end());
|
||||
@@ -89,8 +89,8 @@ symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rh
|
||||
root_ = tree_.size() - 1;
|
||||
}
|
||||
|
||||
symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op) :
|
||||
context_(rhs.context_), dtype_(rhs.dtype_)
|
||||
symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype) :
|
||||
context_(rhs.context_), dtype_(dtype)
|
||||
{
|
||||
tree_.reserve(rhs.tree_.size() + 1);
|
||||
tree_.insert(tree_.end(), rhs.tree_.begin(), rhs.tree_.end());
|
||||
@@ -98,8 +98,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_e
|
||||
root_ = tree_.size() - 1;
|
||||
}
|
||||
|
||||
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op):
|
||||
context_(lhs.context_), dtype_(lhs.dtype_)
|
||||
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype):
|
||||
context_(lhs.context_), dtype_(dtype)
|
||||
{
|
||||
std::size_t lsize = lhs.tree_.size();
|
||||
std::size_t rsize = rhs.tree_.size();
|
||||
@@ -135,16 +135,16 @@ array_expression::array_expression(lhs_rhs_element const & lhs, lhs_rhs_element
|
||||
symbolic_expression(lhs, rhs, op, ctx, dtype), shape_(shape)
|
||||
{ }
|
||||
|
||||
array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op), shape_(shape)
|
||||
array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
|
||||
{ }
|
||||
|
||||
array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op), shape_(shape)
|
||||
array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
|
||||
{ }
|
||||
|
||||
array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op), shape_(shape)
|
||||
array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
|
||||
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
|
||||
{ }
|
||||
|
||||
size4 array_expression::shape() const
|
||||
|
@@ -13,6 +13,7 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
|
||||
using namespace std;
|
||||
|
||||
int failure_count = 0;
|
||||
ad::numeric_type dtype = C.dtype();
|
||||
cl::Context const & ctx = C.context();
|
||||
|
||||
int_t M = cC.size1();
|
||||
@@ -76,14 +77,15 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
|
||||
|
||||
RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
|
||||
RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
|
||||
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
|
||||
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
|
||||
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
|
||||
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
|
||||
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
|
||||
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
|
||||
RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
|
||||
|
||||
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= cast(A==B, dtype))
|
||||
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= cast(A>=B, dtype))
|
||||
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= cast(A>B, dtype))
|
||||
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= cast(A<=B, dtype))
|
||||
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= cast(A<B, dtype))
|
||||
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= cast(A!=B, dtype))
|
||||
|
||||
RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype(), C.context()))
|
||||
RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y))
|
||||
|
||||
|
@@ -74,12 +74,12 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
|
||||
|
||||
RUN_TEST_VECTOR_AXPY("z = x.*y", cz[i] = cx[i]*cy[i], z= x*y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x./y", cz[i] = cx[i]/cy[i], z= x/y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= x==y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= x>=y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= x>y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= x<=y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= x<y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= x!=y)
|
||||
RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= cast(x==y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= cast(x>=y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= cast(x>y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= cast(x<=y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= cast(x<y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= cast(x!=y, dtype))
|
||||
RUN_TEST_VECTOR_AXPY("z = pow(x,y)", cz[i] = pow(cx[i], cy[i]), z= pow(x,y))
|
||||
|
||||
#undef RUN_TEST_VECTOR_AXPY
|
||||
|
Reference in New Issue
Block a user