Bugfix in cast and relational operators

This commit is contained in:
Philippe Tillet
2015-01-29 01:00:50 -05:00
parent c7665021d1
commit d4629ba018
13 changed files with 198 additions and 125 deletions

View File

@@ -2,8 +2,8 @@
#define ATIDLAS_ARRAY_H_ #define ATIDLAS_ARRAY_H_
#include <iostream> #include <iostream>
#include "atidlas/types.h"
#include <CL/cl.hpp> #include <CL/cl.hpp>
#include "atidlas/types.h"
#include "atidlas/cl/queues.h" #include "atidlas/cl/queues.h"
#include "atidlas/symbolic/expression.h" #include "atidlas/symbolic/expression.h"
@@ -118,7 +118,6 @@ public:
//copy //copy
void copy(void const * data, array & gx, cl::CommandQueue & queue, bool blocking = true); void copy(void const * data, array & gx, cl::CommandQueue & queue, bool blocking = true);
void copy(array const & gx, void* data, cl::CommandQueue & queue, bool blocking = true); void copy(array const & gx, void* data, cl::CommandQueue & queue, bool blocking = true);
void copy(void const *data, array &gx, bool blocking = true); void copy(void const *data, array &gx, bool blocking = true);
@@ -192,6 +191,9 @@ ATIDLAS_DECLARE_UNARY_OPERATOR(tan)
ATIDLAS_DECLARE_UNARY_OPERATOR(tanh) ATIDLAS_DECLARE_UNARY_OPERATOR(tanh)
ATIDLAS_DECLARE_UNARY_OPERATOR(trans) ATIDLAS_DECLARE_UNARY_OPERATOR(trans)
array_expression cast(array const &, numeric_type dtype);
array_expression cast(array_expression const &, numeric_type dtype);
array_expression norm(array const &, unsigned int order = 2); array_expression norm(array const &, unsigned int order = 2);
array_expression norm(array_expression const &, unsigned int order = 2); array_expression norm(array_expression const &, unsigned int order = 2);

View File

@@ -235,6 +235,12 @@ public:
mapped_outer(std::string const & scalartype, unsigned int id, node_info info); mapped_outer(std::string const & scalartype, unsigned int id, node_info info);
}; };
class mapped_cast : public mapped_object
{
static std::string operator_to_str(operation_node_type type);
public:
mapped_cast(operation_node_type type, unsigned int id);
};
} }
#endif #endif

View File

@@ -17,6 +17,8 @@ namespace detail
bool is_vector_reduction(symbolic_expression_node const & node); bool is_vector_reduction(symbolic_expression_node const & node);
bool is_elementwise_operator(op_element const & op); bool is_elementwise_operator(op_element const & op);
bool is_elementwise_function(op_element const & op); bool is_elementwise_function(op_element const & op);
bool is_cast(op_element const & op);
} }
class scalar; class scalar;

View File

@@ -193,9 +193,9 @@ public:
typedef std::vector<value_type> container_type; typedef std::vector<value_type> container_type;
symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & context, numeric_type const & dtype); symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & context, numeric_type const & dtype);
symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op); symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype);
symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op); symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op); symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
container_type & tree(); container_type & tree();
container_type const & tree() const; container_type const & tree() const;
@@ -212,9 +212,9 @@ protected:
struct array_expression: public symbolic_expression struct array_expression: public symbolic_expression
{ {
array_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & ctx, numeric_type const & dtype, size4 shape); array_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & ctx, numeric_type const & dtype, size4 shape);
array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape); array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape); array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape); array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
size4 shape() const; size4 shape() const;
array_expression& reshape(int_t size1, int_t size2=1); array_expression& reshape(int_t size1, int_t size2=1);
int_t nshape() const; int_t nshape() const;

View File

@@ -132,6 +132,7 @@ int_t array::dsize() const
//--------------------------------------- //---------------------------------------
array & array::operator=(array const & rhs) array & array::operator=(array const & rhs)
{ {
assert(dtype_ == rhs.dtype());
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_); array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0); cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
model_map_t & mmap = atidlas::get_model_map(queue); model_map_t & mmap = atidlas::get_model_map(queue);
@@ -141,7 +142,8 @@ array & array::operator=(array const & rhs)
array & array::operator=(array_expression const & rhs) array & array::operator=(array_expression const & rhs)
{ {
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), shape_); assert(dtype_ == rhs.dtype());
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), dtype_, shape_);
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0); cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
model_map_t & mmap = atidlas::get_model_map(queue); model_map_t & mmap = atidlas::get_model_map(queue);
execute(expression, mmap); execute(expression, mmap);
@@ -181,7 +183,7 @@ array & array::operator+=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
array & array::operator+=(array_expression const & rhs) array & array::operator+=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), dtype_, shape_); }
//---- //----
array & array::operator-=(value_scalar const & rhs) array & array::operator-=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
@@ -190,7 +192,7 @@ array & array::operator-=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
array & array::operator-=(array_expression const & rhs) array & array::operator-=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
//---- //----
array & array::operator*=(value_scalar const & rhs) array & array::operator*=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
@@ -199,7 +201,7 @@ array & array::operator*=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
array & array::operator*=(array_expression const & rhs) array & array::operator*=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), dtype_, shape_); }
//---- //----
array & array::operator/=(value_scalar const & rhs) array & array::operator/=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
@@ -208,7 +210,7 @@ array & array::operator/=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
array & array::operator/=(array_expression const & rhs) array & array::operator/=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), shape_); } { return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), dtype_, shape_); }
array_expression array::T() const array_expression array::T() const
{ return atidlas::trans(*this) ;} { return atidlas::trans(*this) ;}
@@ -385,51 +387,55 @@ bool check_elementwise(U const & u, V const & v)
return max(u.shape())==1 || max(v.shape())==1 || u.shape()==v.shape(); return max(u.shape())==1 || max(v.shape())==1 || u.shape()==v.shape();
} }
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME) \
array_expression OPNAME (array_expression const & x, array_expression const & y) \ array_expression OPNAME (array_expression const & x, array_expression const & y) \
{ assert(check_elementwise(x, y));\ { assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y) ); } \ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\ \
array_expression OPNAME (array const & x, array_expression const & y) \ array_expression OPNAME (array const & x, array_expression const & y) \
{ assert(check_elementwise(x, y));\ { assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\ \
array_expression OPNAME (array_expression const & x, array const & y) \ array_expression OPNAME (array_expression const & x, array const & y) \
{ assert(check_elementwise(x, y));\ { assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\ \
array_expression OPNAME (array const & x, array const & y) \ array_expression OPNAME (array const & x, array const & y) \
{ assert(check_elementwise(x, y));\ { assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), elementwise_size(x, y)); }\ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); }\
\ \
array_expression OPNAME (array_expression const & x, value_scalar const & y) \ array_expression OPNAME (array_expression const & x, value_scalar const & y) \
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \ { return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
\ \
array_expression OPNAME (array const & x, value_scalar const & y) \ array_expression OPNAME (array const & x, value_scalar const & y) \
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\ { return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\ \
array_expression OPNAME (value_scalar const & y, array_expression const & x) \ array_expression OPNAME (value_scalar const & y, array_expression const & x) \
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \ { return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
\ \
array_expression OPNAME (value_scalar const & y, array const & x) \ array_expression OPNAME (value_scalar const & y, array const & x) \
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); } { return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==) namespace detail
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=) { DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype()) }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
array_expression outer(array const & x, array const & y) array_expression outer(array const & x, array const & y)
{ {
@@ -437,10 +443,6 @@ array_expression outer(array const & x, array const & y)
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(max(x.shape()), max(y.shape())) ); return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(max(x.shape()), max(y.shape())) );
} }
namespace detail
{
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign)
}
#undef DEFINE_ELEMENT_BINARY_OPERATOR #undef DEFINE_ELEMENT_BINARY_OPERATOR
//--------------------------------------- //---------------------------------------
@@ -452,7 +454,7 @@ array_expression OPNAME (array const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\ { return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\ \
array_expression OPNAME (array_expression const & x) \ array_expression OPNAME (array_expression const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.shape()); } { return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs) DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos) DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
@@ -476,15 +478,35 @@ DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
///*--- Misc----*/ ///*--- Misc----*/
////--------------------------------------- ////---------------------------------------
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx) inline operation_node_type casted(numeric_type dtype)
{ {
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); switch(dtype)
{
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype);
}
} }
array_expression cast(array const & x, numeric_type dtype)
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
array_expression cast(array_expression const & x, numeric_type dtype)
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), dtype, x.shape()); }
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
{ return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx) atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
{ { return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
}
inline size4 flip(size4 const & shape) inline size4 flip(size4 const & shape)
{ return size4(shape._2, shape._1);} { return size4(shape._2, shape._1);}
@@ -496,7 +518,7 @@ array_expression trans(array const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\ { return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\ \
array_expression trans(array_expression const & x) \ array_expression trans(array_expression const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), flip(x.shape())); } { return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.dtype(), flip(x.shape())); }
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2) array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
{ {
@@ -515,7 +537,7 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
infos.rep2 = rep2; infos.rep2 = rep2;
infos.sub1 = A.shape()._1; infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2; infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2)); return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
} }
////--------------------------------------- ////---------------------------------------
@@ -540,11 +562,11 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
if(axis < -1 || axis > x.nshape())\ if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\ throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\ if(axis==-1)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), size4(1));\ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(1));\
else if(axis==0)\ else if(axis==0)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._1));\ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._1));\
else\ else\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._2));\ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._2));\
} }
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum) DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
@@ -576,7 +598,7 @@ namespace detail
shape._1 = A.shape()._2; shape._1 = A.shape()._2;
} }
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape); array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]); symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
return res; return res;
@@ -593,7 +615,7 @@ namespace detail
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE; type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
shape._2 = B.shape()._1; shape._2 = B.shape()._1;
} }
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape); array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]); symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;
return res; return res;
@@ -615,7 +637,7 @@ namespace detail
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE; else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE; else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape); array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]); symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;
@@ -639,7 +661,7 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans) if(A_trans)
{ {
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M)); array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.dtype(), size4(N, M));
//Remove trans //Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs; tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 1); return sum(tmp, 1);

View File

@@ -335,4 +335,27 @@ void mapped_outer::postprocess(std::string &res) const
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info) mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
{ } { }
std::string mapped_cast::operator_to_str(operation_node_type type)
{
switch(type)
{
case OPERATOR_CAST_CHAR_TYPE : return "char";
case OPERATOR_CAST_UCHAR_TYPE : return "uchar";
case OPERATOR_CAST_SHORT_TYPE : return "short";
case OPERATOR_CAST_USHORT_TYPE : return "ushort";
case OPERATOR_CAST_INT_TYPE : return "int";
case OPERATOR_CAST_UINT_TYPE : return "uint";
case OPERATOR_CAST_LONG_TYPE : return "long";
case OPERATOR_CAST_ULONG_TYPE : return "ulong";
case OPERATOR_CAST_HALF_TYPE : return "half";
case OPERATOR_CAST_FLOAT_TYPE : return "float";
case OPERATOR_CAST_DOUBLE_TYPE : return "double";
default : return "invalid";
}
}
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
{ }
} }

View File

@@ -8,20 +8,7 @@ namespace atidlas
namespace detail namespace detail
{ {
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type==OPERATOR_OUTER_PROD_TYPE;
}
bool is_scalar_reduction(symbolic_expression_node const & node) bool is_scalar_reduction(symbolic_expression_node const & node)
{ {
@@ -44,13 +31,50 @@ namespace detail
|| op.type== OPERATOR_ELEMENT_PROD_TYPE || op.type== OPERATOR_ELEMENT_PROD_TYPE
|| op.type== OPERATOR_ELEMENT_DIV_TYPE || op.type== OPERATOR_ELEMENT_DIV_TYPE
|| op.type== OPERATOR_MULT_TYPE || op.type== OPERATOR_MULT_TYPE
|| op.type== OPERATOR_DIV_TYPE; || op.type== OPERATOR_DIV_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
}
bool is_cast(op_element const & op)
{
return op.type== OPERATOR_CAST_CHAR_TYPE
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|| op.type== OPERATOR_CAST_SHORT_TYPE
|| op.type== OPERATOR_CAST_USHORT_TYPE
|| op.type== OPERATOR_CAST_INT_TYPE
|| op.type== OPERATOR_CAST_UINT_TYPE
|| op.type== OPERATOR_CAST_LONG_TYPE
|| op.type== OPERATOR_CAST_ULONG_TYPE
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
;
}
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
;
} }
bool is_elementwise_function(op_element const & op) bool is_elementwise_function(op_element const & op)
{ {
return return op.type == OPERATOR_CAST_CHAR_TYPE
op.type == OPERATOR_CAST_CHAR_TYPE
|| op.type == OPERATOR_CAST_UCHAR_TYPE || op.type == OPERATOR_CAST_UCHAR_TYPE
|| op.type == OPERATOR_CAST_SHORT_TYPE || op.type == OPERATOR_CAST_SHORT_TYPE
|| op.type == OPERATOR_CAST_USHORT_TYPE || op.type == OPERATOR_CAST_USHORT_TYPE
@@ -81,12 +105,6 @@ namespace detail
|| op.type== OPERATOR_TANH_TYPE || op.type== OPERATOR_TANH_TYPE
|| op.type== OPERATOR_ELEMENT_POW_TYPE || op.type== OPERATOR_ELEMENT_POW_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE || op.type== OPERATOR_ELEMENT_FMAX_TYPE
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE || op.type== OPERATOR_ELEMENT_FMIN_TYPE
|| op.type== OPERATOR_ELEMENT_MAX_TYPE || op.type== OPERATOR_ELEMENT_MAX_TYPE
@@ -94,6 +112,8 @@ namespace detail
} }
} }
// //
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out) filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
@@ -161,18 +181,6 @@ const char * evaluate(operation_node_type type)
case OPERATOR_TAN_TYPE : return "tan"; case OPERATOR_TAN_TYPE : return "tan";
case OPERATOR_TANH_TYPE : return "tanh"; case OPERATOR_TANH_TYPE : return "tanh";
case OPERATOR_CAST_CHAR_TYPE : return "(char)";
case OPERATOR_CAST_UCHAR_TYPE : return "(uchar)";
case OPERATOR_CAST_SHORT_TYPE : return "(short)";
case OPERATOR_CAST_USHORT_TYPE : return "(ushort)";
case OPERATOR_CAST_INT_TYPE : return "(int)";
case OPERATOR_CAST_UINT_TYPE : return "(uint)";
case OPERATOR_CAST_LONG_TYPE : return "(long)";
case OPERATOR_CAST_ULONG_TYPE : return "(ulong)";
case OPERATOR_CAST_HALF_TYPE : return "(half)";
case OPERATOR_CAST_FLOAT_TYPE : return "(float)";
case OPERATOR_CAST_DOUBLE_TYPE : return "(double)";
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax"; case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax"; case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin"; case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
@@ -193,12 +201,12 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ACCESS_TYPE : return "[]"; case OPERATOR_ACCESS_TYPE : return "[]";
//Relational //Relational
case OPERATOR_ELEMENT_EQ_TYPE : return "isequal"; case OPERATOR_ELEMENT_EQ_TYPE : return "==";
case OPERATOR_ELEMENT_NEQ_TYPE : return "isnotequal"; case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
case OPERATOR_ELEMENT_GREATER_TYPE : return "isgreater"; case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
case OPERATOR_ELEMENT_GEQ_TYPE : return "isgreaterequal"; case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
case OPERATOR_ELEMENT_LESS_TYPE : return "isless"; case OPERATOR_ELEMENT_LESS_TYPE : return "<";
case OPERATOR_ELEMENT_LEQ_TYPE : return "islessequal"; case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax"; case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin"; case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
@@ -261,7 +269,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const
{ {
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx]; symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op)) if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
&& !detail::is_node_leaf(root_node.op)) && !detail::is_node_leaf(root_node.op))
str_+=evaluate(root_node.op.type); str_+=evaluate(root_node.op.type);
str_+="("; str_+="(";

View File

@@ -110,6 +110,8 @@ void base::map_functor::operator()(atidlas::symbolic_expression const & symbolic
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE) else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&symbolic_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&symbolic_expression, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, tools::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get(NULL)))));
} }
} }

View File

@@ -28,7 +28,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
kernel_generation_stream stream; kernel_generation_stream stream;
std::string init0, upper_bound0, inc0, init1, upper_bound1, inc1; std::string init0, upper_bound0, inc0, init1, upper_bound1, inc1;
std::string data_type = append_width("#scalartype",simd_width);
char kprefix[10]; char kprefix[10];
fill_kernel_name(kprefix, label, "d"); fill_kernel_name(kprefix, label, "d");
@@ -51,7 +51,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream.inc_tab(); stream.inc_tab();
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> > process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >
("array2", append_width("#scalartype",simd_width) + " #namereg = $VALUE{i*#stride1,j*#stride2};") ("array2", data_type + " #namereg = $VALUE{i*#stride1,j*#stride2};")
("vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};") ("vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};")
("repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};") ("repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};")
("outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});") ("outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});")
@@ -63,6 +63,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
("repeat", "#namereg") ("repeat", "#namereg")
("array0", "#namereg") ("array0", "#namereg")
("outer", "#namereg") ("outer", "#namereg")
("cast", "convert_"+data_type)
, symbolic_expressions, mappings); , symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;") process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;")

View File

@@ -61,7 +61,9 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
("matrix_row", "#namereg") ("matrix_row", "#namereg")
("matrix_column", "#namereg") ("matrix_column", "#namereg")
("matrix_diag", "#namereg") ("matrix_diag", "#namereg")
("array0", "#namereg"), symbolic_expressions, mappings); ("array0", "#namereg")
("cast", "convert_"+data_type)
, symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array1", "#pointer[i*#stride] = #namereg;") process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array1", "#pointer[i*#stride] = #namereg;")
("matrix_row", "$VALUE{#row, i} = #namereg;") ("matrix_row", "$VALUE{#row, i} = #namereg;")
@@ -82,6 +84,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
// std::cout << stream.str() << std::endl;
result.push_back(stream.str()); result.push_back(stream.str());
} }

View File

@@ -22,7 +22,7 @@ void fill(array const & a, array_infos& i)
} }
array_expression array_expression::operator-() array_expression array_expression::operator-()
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); } { return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
lhs_rhs_element::lhs_rhs_element() lhs_rhs_element::lhs_rhs_element()
@@ -80,8 +80,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_el
tree_(1, symbolic_expression_node(lhs, op, rhs)), root_(0), context_(context), dtype_(dtype) tree_(1, symbolic_expression_node(lhs, op, rhs)), root_(0), context_(context), dtype_(dtype)
{ } { }
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op) : symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype) :
context_(lhs.context_), dtype_(lhs.dtype_) context_(lhs.context_), dtype_(dtype)
{ {
tree_.reserve(lhs.tree_.size() + 1); tree_.reserve(lhs.tree_.size() + 1);
tree_.insert(tree_.end(), lhs.tree_.begin(), lhs.tree_.end()); tree_.insert(tree_.end(), lhs.tree_.begin(), lhs.tree_.end());
@@ -89,8 +89,8 @@ symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rh
root_ = tree_.size() - 1; root_ = tree_.size() - 1;
} }
symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op) : symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype) :
context_(rhs.context_), dtype_(rhs.dtype_) context_(rhs.context_), dtype_(dtype)
{ {
tree_.reserve(rhs.tree_.size() + 1); tree_.reserve(rhs.tree_.size() + 1);
tree_.insert(tree_.end(), rhs.tree_.begin(), rhs.tree_.end()); tree_.insert(tree_.end(), rhs.tree_.begin(), rhs.tree_.end());
@@ -98,8 +98,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_e
root_ = tree_.size() - 1; root_ = tree_.size() - 1;
} }
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op): symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype):
context_(lhs.context_), dtype_(lhs.dtype_) context_(lhs.context_), dtype_(dtype)
{ {
std::size_t lsize = lhs.tree_.size(); std::size_t lsize = lhs.tree_.size();
std::size_t rsize = rhs.tree_.size(); std::size_t rsize = rhs.tree_.size();
@@ -135,16 +135,16 @@ array_expression::array_expression(lhs_rhs_element const & lhs, lhs_rhs_element
symbolic_expression(lhs, rhs, op, ctx, dtype), shape_(shape) symbolic_expression(lhs, rhs, op, ctx, dtype), shape_(shape)
{ } { }
array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape): array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape) symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ } { }
array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape): array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape) symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ } { }
array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape): array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape) symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ } { }
size4 array_expression::shape() const size4 array_expression::shape() const

View File

@@ -13,6 +13,7 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
using namespace std; using namespace std;
int failure_count = 0; int failure_count = 0;
ad::numeric_type dtype = C.dtype();
cl::Context const & ctx = C.context(); cl::Context const & ctx = C.context();
int_t M = cC.size1(); int_t M = cC.size1();
@@ -76,14 +77,15 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B) RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B) RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B)) RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= cast(A==B, dtype))
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= cast(A>=B, dtype))
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= cast(A>B, dtype))
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= cast(A<=B, dtype))
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= cast(A<B, dtype))
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= cast(A!=B, dtype))
RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype(), C.context())) RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype(), C.context()))
RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y)) RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y))

View File

@@ -74,12 +74,12 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
RUN_TEST_VECTOR_AXPY("z = x.*y", cz[i] = cx[i]*cy[i], z= x*y) RUN_TEST_VECTOR_AXPY("z = x.*y", cz[i] = cx[i]*cy[i], z= x*y)
RUN_TEST_VECTOR_AXPY("z = x./y", cz[i] = cx[i]/cy[i], z= x/y) RUN_TEST_VECTOR_AXPY("z = x./y", cz[i] = cx[i]/cy[i], z= x/y)
RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= x==y) RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= cast(x==y, dtype))
RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= x>=y) RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= cast(x>=y, dtype))
RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= x>y) RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= cast(x>y, dtype))
RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= x<=y) RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= cast(x<=y, dtype))
RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= x<y) RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= cast(x<y, dtype))
RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= x!=y) RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= cast(x!=y, dtype))
RUN_TEST_VECTOR_AXPY("z = pow(x,y)", cz[i] = pow(cx[i], cy[i]), z= pow(x,y)) RUN_TEST_VECTOR_AXPY("z = pow(x,y)", cz[i] = pow(cx[i], cy[i]), z= pow(x,y))
#undef RUN_TEST_VECTOR_AXPY #undef RUN_TEST_VECTOR_AXPY