Bugfix in cast and relational operators

This commit is contained in:
Philippe Tillet
2015-01-29 01:00:50 -05:00
parent c7665021d1
commit d4629ba018
13 changed files with 198 additions and 125 deletions

View File

@@ -2,8 +2,8 @@
#define ATIDLAS_ARRAY_H_
#include <iostream>
#include "atidlas/types.h"
#include <CL/cl.hpp>
#include "atidlas/types.h"
#include "atidlas/cl/queues.h"
#include "atidlas/symbolic/expression.h"
@@ -118,7 +118,6 @@ public:
//copy
void copy(void const * data, array & gx, cl::CommandQueue & queue, bool blocking = true);
void copy(array const & gx, void* data, cl::CommandQueue & queue, bool blocking = true);
void copy(void const *data, array &gx, bool blocking = true);
@@ -192,6 +191,9 @@ ATIDLAS_DECLARE_UNARY_OPERATOR(tan)
ATIDLAS_DECLARE_UNARY_OPERATOR(tanh)
ATIDLAS_DECLARE_UNARY_OPERATOR(trans)
array_expression cast(array const &, numeric_type dtype);
array_expression cast(array_expression const &, numeric_type dtype);
array_expression norm(array const &, unsigned int order = 2);
array_expression norm(array_expression const &, unsigned int order = 2);

View File

@@ -235,6 +235,12 @@ public:
mapped_outer(std::string const & scalartype, unsigned int id, node_info info);
};
class mapped_cast : public mapped_object
{
static std::string operator_to_str(operation_node_type type);
public:
mapped_cast(operation_node_type type, unsigned int id);
};
}
#endif

View File

@@ -17,6 +17,8 @@ namespace detail
bool is_vector_reduction(symbolic_expression_node const & node);
bool is_elementwise_operator(op_element const & op);
bool is_elementwise_function(op_element const & op);
bool is_cast(op_element const & op);
}
class scalar;

View File

@@ -193,9 +193,9 @@ public:
typedef std::vector<value_type> container_type;
symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & context, numeric_type const & dtype);
symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op);
symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op);
symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op);
symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype);
symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype);
container_type & tree();
container_type const & tree() const;
@@ -212,9 +212,9 @@ protected:
struct array_expression: public symbolic_expression
{
array_expression(lhs_rhs_element const & lhs, lhs_rhs_element const & rhs, op_element const & op, cl::Context const & ctx, numeric_type const & dtype, size4 shape);
array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape);
array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape);
array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape);
array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape);
size4 shape() const;
array_expression& reshape(int_t size1, int_t size2=1);
int_t nshape() const;

View File

@@ -132,6 +132,7 @@ int_t array::dsize() const
//---------------------------------------
array & array::operator=(array const & rhs)
{
assert(dtype_ == rhs.dtype());
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
model_map_t & mmap = atidlas::get_model_map(queue);
@@ -141,7 +142,8 @@ array & array::operator=(array const & rhs)
array & array::operator=(array_expression const & rhs)
{
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), shape_);
assert(dtype_ == rhs.dtype());
array_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), dtype_, shape_);
cl::CommandQueue & queue = cl_ext::get_queue(context_, 0);
model_map_t & mmap = atidlas::get_model_map(queue);
execute(expression, mmap);
@@ -181,7 +183,7 @@ array & array::operator+=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
array & array::operator+=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), shape_); }
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), dtype_, shape_); }
//----
array & array::operator-=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
@@ -190,7 +192,7 @@ array & array::operator-=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
array & array::operator-=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
//----
array & array::operator*=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
@@ -199,7 +201,7 @@ array & array::operator*=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
array & array::operator*=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), shape_); }
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), dtype_, shape_); }
//----
array & array::operator/=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
@@ -208,7 +210,7 @@ array & array::operator/=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
array & array::operator/=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), shape_); }
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), dtype_, shape_); }
array_expression array::T() const
{ return atidlas::trans(*this) ;}
@@ -385,51 +387,55 @@ bool check_elementwise(U const & u, V const & v)
return max(u.shape())==1 || max(v.shape())==1 || u.shape()==v.shape();
}
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME) \
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
array_expression OPNAME (array_expression const & x, array_expression const & y) \
{ assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y) ); } \
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\
array_expression OPNAME (array const & x, array_expression const & y) \
{ assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\
array_expression OPNAME (array_expression const & x, array const & y) \
{ assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), elementwise_size(x, y)); } \
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, elementwise_size(x, y)); } \
\
array_expression OPNAME (array const & x, array const & y) \
{ assert(check_elementwise(x, y));\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), elementwise_size(x, y)); }\
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); }\
\
array_expression OPNAME (array_expression const & x, value_scalar const & y) \
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
\
array_expression OPNAME (array const & x, value_scalar const & y) \
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
array_expression OPNAME (value_scalar const & y, array_expression const & x) \
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.shape()); } \
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE, x.shape()); } \
\
array_expression OPNAME (value_scalar const & y, array const & x) \
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
namespace detail
{ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype()) }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
array_expression outer(array const & x, array const & y)
{
@@ -437,10 +443,6 @@ array_expression outer(array const & x, array const & y)
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(max(x.shape()), max(y.shape())) );
}
namespace detail
{
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign)
}
#undef DEFINE_ELEMENT_BINARY_OPERATOR
//---------------------------------------
@@ -452,7 +454,7 @@ array_expression OPNAME (array const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\
array_expression OPNAME (array_expression const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.shape()); }
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
@@ -476,15 +478,35 @@ DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
///*--- Misc----*/
////---------------------------------------
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
inline operation_node_type casted(numeric_type dtype)
{
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
switch(dtype)
{
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype);
}
}
array_expression cast(array const & x, numeric_type dtype)
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
array_expression cast(array_expression const & x, numeric_type dtype)
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), dtype, x.shape()); }
atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
{ return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
{
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
}
{ return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
inline size4 flip(size4 const & shape)
{ return size4(shape._2, shape._1);}
@@ -496,7 +518,7 @@ array_expression trans(array const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\
array_expression trans(array_expression const & x) \
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), flip(x.shape())); }
{ return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.dtype(), flip(x.shape())); }
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
{
@@ -515,7 +537,7 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
infos.rep2 = rep2;
infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
}
////---------------------------------------
@@ -540,11 +562,11 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), size4(1));\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(1));\
else if(axis==0)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._1));\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._1));\
else\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._2));\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.dtype(), size4(x.shape()._2));\
}
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
@@ -576,7 +598,7 @@ namespace detail
shape._1 = A.shape()._2;
}
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
return res;
@@ -593,7 +615,7 @@ namespace detail
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
shape._2 = B.shape()._1;
}
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs;
return res;
@@ -615,7 +637,7 @@ namespace detail
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), shape);
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.dtype(), shape);
symbolic_expression_node & res_root = const_cast<symbolic_expression_node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs;
@@ -639,7 +661,7 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans)
{
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M));
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.dtype(), size4(N, M));
//Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 1);

View File

@@ -335,4 +335,27 @@ void mapped_outer::postprocess(std::string &res) const
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
{ }
std::string mapped_cast::operator_to_str(operation_node_type type)
{
switch(type)
{
case OPERATOR_CAST_CHAR_TYPE : return "char";
case OPERATOR_CAST_UCHAR_TYPE : return "uchar";
case OPERATOR_CAST_SHORT_TYPE : return "short";
case OPERATOR_CAST_USHORT_TYPE : return "ushort";
case OPERATOR_CAST_INT_TYPE : return "int";
case OPERATOR_CAST_UINT_TYPE : return "uint";
case OPERATOR_CAST_LONG_TYPE : return "long";
case OPERATOR_CAST_ULONG_TYPE : return "ulong";
case OPERATOR_CAST_HALF_TYPE : return "half";
case OPERATOR_CAST_FLOAT_TYPE : return "float";
case OPERATOR_CAST_DOUBLE_TYPE : return "double";
default : return "invalid";
}
}
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
{ }
}

View File

@@ -8,20 +8,7 @@ namespace atidlas
namespace detail
{
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type==OPERATOR_OUTER_PROD_TYPE;
}
bool is_scalar_reduction(symbolic_expression_node const & node)
{
@@ -44,13 +31,50 @@ namespace detail
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|| op.type== OPERATOR_MULT_TYPE
|| op.type== OPERATOR_DIV_TYPE;
|| op.type== OPERATOR_DIV_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
}
bool is_cast(op_element const & op)
{
return op.type== OPERATOR_CAST_CHAR_TYPE
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|| op.type== OPERATOR_CAST_SHORT_TYPE
|| op.type== OPERATOR_CAST_USHORT_TYPE
|| op.type== OPERATOR_CAST_INT_TYPE
|| op.type== OPERATOR_CAST_UINT_TYPE
|| op.type== OPERATOR_CAST_LONG_TYPE
|| op.type== OPERATOR_CAST_ULONG_TYPE
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
;
}
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_TRANS_TYPE
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
;
}
bool is_elementwise_function(op_element const & op)
{
return
op.type == OPERATOR_CAST_CHAR_TYPE
return op.type == OPERATOR_CAST_CHAR_TYPE
|| op.type == OPERATOR_CAST_UCHAR_TYPE
|| op.type == OPERATOR_CAST_SHORT_TYPE
|| op.type == OPERATOR_CAST_USHORT_TYPE
@@ -81,12 +105,6 @@ namespace detail
|| op.type== OPERATOR_TANH_TYPE
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
@@ -94,6 +112,8 @@ namespace detail
}
}
//
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
@@ -161,18 +181,6 @@ const char * evaluate(operation_node_type type)
case OPERATOR_TAN_TYPE : return "tan";
case OPERATOR_TANH_TYPE : return "tanh";
case OPERATOR_CAST_CHAR_TYPE : return "(char)";
case OPERATOR_CAST_UCHAR_TYPE : return "(uchar)";
case OPERATOR_CAST_SHORT_TYPE : return "(short)";
case OPERATOR_CAST_USHORT_TYPE : return "(ushort)";
case OPERATOR_CAST_INT_TYPE : return "(int)";
case OPERATOR_CAST_UINT_TYPE : return "(uint)";
case OPERATOR_CAST_LONG_TYPE : return "(long)";
case OPERATOR_CAST_ULONG_TYPE : return "(ulong)";
case OPERATOR_CAST_HALF_TYPE : return "(half)";
case OPERATOR_CAST_FLOAT_TYPE : return "(float)";
case OPERATOR_CAST_DOUBLE_TYPE : return "(double)";
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
@@ -193,12 +201,12 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ACCESS_TYPE : return "[]";
//Relational
case OPERATOR_ELEMENT_EQ_TYPE : return "isequal";
case OPERATOR_ELEMENT_NEQ_TYPE : return "isnotequal";
case OPERATOR_ELEMENT_GREATER_TYPE : return "isgreater";
case OPERATOR_ELEMENT_GEQ_TYPE : return "isgreaterequal";
case OPERATOR_ELEMENT_LESS_TYPE : return "isless";
case OPERATOR_ELEMENT_LEQ_TYPE : return "islessequal";
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
@@ -261,7 +269,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
void evaluate_expression_traversal::call_before_expansion(atidlas::symbolic_expression const & symbolic_expression, int_t root_idx) const
{
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if ((root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY || detail::is_elementwise_function(root_node.op))
&& !detail::is_node_leaf(root_node.op))
str_+=evaluate(root_node.op.type);
str_+="(";

View File

@@ -110,6 +110,8 @@ void base::map_functor::operator()(atidlas::symbolic_expression const & symbolic
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&symbolic_expression, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, tools::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get(NULL)))));
}
}

View File

@@ -28,7 +28,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
kernel_generation_stream stream;
std::string init0, upper_bound0, inc0, init1, upper_bound1, inc1;
std::string data_type = append_width("#scalartype",simd_width);
char kprefix[10];
fill_kernel_name(kprefix, label, "d");
@@ -51,7 +51,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream.inc_tab();
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >
("array2", append_width("#scalartype",simd_width) + " #namereg = $VALUE{i*#stride1,j*#stride2};")
("array2", data_type + " #namereg = $VALUE{i*#stride1,j*#stride2};")
("vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};")
("repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};")
("outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});")
@@ -63,6 +63,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
("repeat", "#namereg")
("array0", "#namereg")
("outer", "#namereg")
("cast", "convert_"+data_type)
, symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;")

View File

@@ -61,7 +61,9 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
("matrix_row", "#namereg")
("matrix_column", "#namereg")
("matrix_diag", "#namereg")
("array0", "#namereg"), symbolic_expressions, mappings);
("array0", "#namereg")
("cast", "convert_"+data_type)
, symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array1", "#pointer[i*#stride] = #namereg;")
("matrix_row", "$VALUE{#row, i} = #namereg;")
@@ -82,6 +84,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
stream.dec_tab();
stream << "}" << std::endl;
// std::cout << stream.str() << std::endl;
result.push_back(stream.str());
}

View File

@@ -22,7 +22,7 @@ void fill(array const & a, array_infos& i)
}
array_expression array_expression::operator-()
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), dtype_, shape_); }
lhs_rhs_element::lhs_rhs_element()
@@ -80,8 +80,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, lhs_rhs_el
tree_(1, symbolic_expression_node(lhs, op, rhs)), root_(0), context_(context), dtype_(dtype)
{ }
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op) :
context_(lhs.context_), dtype_(lhs.dtype_)
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype) :
context_(lhs.context_), dtype_(dtype)
{
tree_.reserve(lhs.tree_.size() + 1);
tree_.insert(tree_.end(), lhs.tree_.begin(), lhs.tree_.end());
@@ -89,8 +89,8 @@ symbolic_expression::symbolic_expression(symbolic_expression const & lhs, lhs_rh
root_ = tree_.size() - 1;
}
symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op) :
context_(rhs.context_), dtype_(rhs.dtype_)
symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype) :
context_(rhs.context_), dtype_(dtype)
{
tree_.reserve(rhs.tree_.size() + 1);
tree_.insert(tree_.end(), rhs.tree_.begin(), rhs.tree_.end());
@@ -98,8 +98,8 @@ symbolic_expression::symbolic_expression(lhs_rhs_element const & lhs, symbolic_e
root_ = tree_.size() - 1;
}
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op):
context_(lhs.context_), dtype_(lhs.dtype_)
symbolic_expression::symbolic_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype):
context_(lhs.context_), dtype_(dtype)
{
std::size_t lsize = lhs.tree_.size();
std::size_t rsize = rhs.tree_.size();
@@ -135,16 +135,16 @@ array_expression::array_expression(lhs_rhs_element const & lhs, lhs_rhs_element
symbolic_expression(lhs, rhs, op, ctx, dtype), shape_(shape)
{ }
array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape)
array_expression::array_expression(symbolic_expression const & lhs, lhs_rhs_element const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ }
array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape)
array_expression::array_expression(lhs_rhs_element const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ }
array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, size4 shape):
symbolic_expression(lhs, rhs, op), shape_(shape)
array_expression::array_expression(symbolic_expression const & lhs, symbolic_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 shape):
symbolic_expression(lhs, rhs, op, dtype), shape_(shape)
{ }
size4 array_expression::shape() const

View File

@@ -13,6 +13,7 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
using namespace std;
int failure_count = 0;
ad::numeric_type dtype = C.dtype();
cl::Context const & ctx = C.context();
int_t M = cC.size1();
@@ -76,14 +77,15 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= cast(A==B, dtype))
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= cast(A>=B, dtype))
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= cast(A>B, dtype))
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= cast(A<=B, dtype))
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= cast(A<B, dtype))
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= cast(A!=B, dtype))
RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype(), C.context()))
RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y))

View File

@@ -74,12 +74,12 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
RUN_TEST_VECTOR_AXPY("z = x.*y", cz[i] = cx[i]*cy[i], z= x*y)
RUN_TEST_VECTOR_AXPY("z = x./y", cz[i] = cx[i]/cy[i], z= x/y)
RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= x==y)
RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= x>=y)
RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= x>y)
RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= x<=y)
RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= x<y)
RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= x!=y)
RUN_TEST_VECTOR_AXPY("z = x==y", cz[i] = cx[i]==cy[i], z= cast(x==y, dtype))
RUN_TEST_VECTOR_AXPY("z = x>=y", cz[i] = cx[i]>=cy[i], z= cast(x>=y, dtype))
RUN_TEST_VECTOR_AXPY("z = x>y", cz[i] = cx[i]>cy[i], z= cast(x>y, dtype))
RUN_TEST_VECTOR_AXPY("z = x<=y", cz[i] = cx[i]<=cy[i], z= cast(x<=y, dtype))
RUN_TEST_VECTOR_AXPY("z = x<y", cz[i] = cx[i]<cy[i], z= cast(x<y, dtype))
RUN_TEST_VECTOR_AXPY("z = x!=y", cz[i] = cx[i]!=cy[i], z= cast(x!=y, dtype))
RUN_TEST_VECTOR_AXPY("z = pow(x,y)", cz[i] = pow(cx[i], cy[i]), z= pow(x,y))
#undef RUN_TEST_VECTOR_AXPY