low level representation of array

This commit is contained in:
Philippe Tillet
2015-01-18 14:52:45 -05:00
parent 16648f18e0
commit edaa821d93
17 changed files with 243 additions and 194 deletions

View File

@@ -15,6 +15,7 @@ class scalar;
class array: public obj_base class array: public obj_base
{ {
friend array reshape(array const &, int_t, int_t);
public: public:
//1D Constructors //1D Constructors
array(int_t size1, numeric_type dtype, cl::Context context = cl::default_context()); array(int_t size1, numeric_type dtype, cl::Context context = cl::default_context());
@@ -29,9 +30,10 @@ public:
array(array & M, slice const & s1, slice const & s2); array(array & M, slice const & s1, slice const & s2);
//General constructor //General constructor
array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, cl::Context context = cl::default_context()); array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context = cl::default_context());
explicit array(array_expression const & proxy); array(array_expression const & proxy);
array(array const &);
//Getters //Getters
numeric_type dtype() const; numeric_type dtype() const;
size4 shape() const; size4 shape() const;
@@ -45,7 +47,6 @@ public:
//Setters //Setters
array& resize(int_t size1, int_t size2=1); array& resize(int_t size1, int_t size2=1);
array& reshape(int_t size1, int_t size2=1);
//Numeric operators //Numeric operators
array& operator=(array const &); array& operator=(array const &);
@@ -114,8 +115,8 @@ public:
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context()); atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
array_expression zeros(std::size_t N, numeric_type dtype); array_expression zeros(std::size_t N, numeric_type dtype);
array reshape(array const &, int_t, int_t);
//copy //copy
@@ -197,13 +198,6 @@ array_expression norm(array_expression const &, unsigned int order = 2);
#undef ATIDLAS_DECLARE_UNARY_OPERATOR #undef ATIDLAS_DECLARE_UNARY_OPERATOR
struct repeat_infos
{
repeat_infos(size4 const & _sub, size4 const & _rep) : sub(_sub), rep(_rep){ }
size4 sub;
size4 rep;
};
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2); array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
#define ATIDLAS_DECLARE_REDUCTION(OPNAME) \ #define ATIDLAS_DECLARE_REDUCTION(OPNAME) \

View File

@@ -18,8 +18,8 @@ class symbolic_binder
{ {
public: public:
virtual ~symbolic_binder(); virtual ~symbolic_binder();
virtual bool bind(cl::Buffer const * ph) = 0; virtual bool bind(cl_mem ph) = 0;
virtual unsigned int get(cl::Buffer const * ph) = 0; virtual unsigned int get(cl_mem ph) = 0;
}; };
@@ -27,8 +27,8 @@ class bind_to_handle : public symbolic_binder
{ {
public: public:
bind_to_handle(); bind_to_handle();
bool bind(cl::Buffer const * ph); bool bind(cl_mem ph);
unsigned int get(cl::Buffer const * ph); unsigned int get(cl_mem ph);
private: private:
unsigned int current_arg_; unsigned int current_arg_;
std::map<void*,unsigned int> memory; std::map<void*,unsigned int> memory;
@@ -38,8 +38,8 @@ class bind_all_unique : public symbolic_binder
{ {
public: public:
bind_all_unique(); bind_all_unique();
bool bind(cl::Buffer const *); bool bind(cl_mem);
unsigned int get(cl::Buffer const *); unsigned int get(cl_mem);
private: private:
unsigned int current_arg_; unsigned int current_arg_;
std::map<void*,unsigned int> memory; std::map<void*,unsigned int> memory;

View File

@@ -139,7 +139,7 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::strin
class symbolic_expression_representation_functor : public traversal_functor{ class symbolic_expression_representation_functor : public traversal_functor{
private: private:
static void append_id(char * & ptr, unsigned int val); static void append_id(char * & ptr, unsigned int val);
void append(cl::Buffer const * h, numeric_type dtype, char prefix) const; void append(cl_mem h, numeric_type dtype, char prefix) const;
void append(lhs_rhs_element const & lhs_rhs) const; void append(lhs_rhs_element const & lhs_rhs) const;
public: public:
symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr); symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr);

View File

@@ -79,7 +79,7 @@ protected:
/** @brief Creates a value scalar mapping */ /** @brief Creates a value scalar mapping */
tools::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const; tools::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
/** @brief Creates a vector mapping */ /** @brief Creates a vector mapping */
tools::shared_ptr<mapped_object> create(array const &) const; tools::shared_ptr<mapped_object> create(array_infos const &) const;
/** @brief Creates a tuple mapping */ /** @brief Creates a tuple mapping */
tools::shared_ptr<mapped_object> create(repeat_infos const &) const; tools::shared_ptr<mapped_object> create(repeat_infos const &) const;
/** @brief Creates a mapping */ /** @brief Creates a mapping */
@@ -101,7 +101,7 @@ protected:
set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, cl::Kernel & kernel); set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, cl::Kernel & kernel);
void set_arguments(numeric_type dtype, values_holder const & scal) const; void set_arguments(numeric_type dtype, values_holder const & scal) const;
void set_arguments(array const & ) const; void set_arguments(array_infos const & ) const;
void set_arguments(repeat_infos const & i) const; void set_arguments(repeat_infos const & i) const;
void set_arguments(lhs_rhs_element const & lhs_rhs) const; void set_arguments(lhs_rhs_element const & lhs_rhs) const;

View File

@@ -39,10 +39,10 @@ private:
std::string generate_impl(unsigned int label, char id, const symbolic_expressions_container &symbolic_expressions, const std::vector<mapping_type> &, bool fallback) const; std::string generate_impl(unsigned int label, char id, const symbolic_expressions_container &symbolic_expressions, const std::vector<mapping_type> &, bool fallback) const;
std::vector<std::string> generate_impl(unsigned int label, symbolic_expressions_container const & symbolic_expressions, std::vector<mapping_type> const & mappings) const; std::vector<std::string> generate_impl(unsigned int label, symbolic_expressions_container const & symbolic_expressions, std::vector<mapping_type> const & mappings) const;
void enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K, void enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C, array_infos const & A, array_infos const & B, array_infos const & C,
value_scalar const & alpha, value_scalar const & beta, value_scalar const & alpha, value_scalar const & beta,
std::vector<cl::lazy_compiler> & programs, unsigned int label, int id); std::vector<cl::lazy_compiler> & programs, unsigned int label, int id);
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap); array_infos create_slice(array_infos & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(symbolic_expressions_container const & symbolic_expressions, std::vector<int_t> infos(symbolic_expressions_container const & symbolic_expressions,
lhs_rhs_element & C, lhs_rhs_element & A, lhs_rhs_element & B); lhs_rhs_element & C, lhs_rhs_element & A, lhs_rhs_element & B);
public: public:

View File

@@ -133,6 +133,20 @@ enum symbolic_expression_node_subtype
REPEAT_INFOS_TYPE REPEAT_INFOS_TYPE
}; };
struct array_infos
{
numeric_type dtype;
cl_mem data;
int_t shape1;
int_t shape2;
int_t start1;
int_t start2;
int_t stride1;
int_t stride2;
int_t ld;
};
void fill(array const & a, array_infos& i);
struct lhs_rhs_element struct lhs_rhs_element
{ {
lhs_rhs_element(); lhs_rhs_element();
@@ -144,14 +158,14 @@ struct lhs_rhs_element
symbolic_expression_node_type_family type_family; symbolic_expression_node_type_family type_family;
symbolic_expression_node_subtype subtype; symbolic_expression_node_subtype subtype;
numeric_type dtype; numeric_type dtype;
union union
{ {
unsigned int node_index; unsigned int node_index;
atidlas::array * array;
values_holder vscalar; values_holder vscalar;
atidlas::repeat_infos * tuple; repeat_infos tuple;
array_infos array;
}; };
cl::Buffer memory_;
}; };

View File

@@ -20,6 +20,15 @@ namespace atidlas
inline int_t max(size4 const & s) { return std::max(s._1, s._2); } inline int_t max(size4 const & s) { return std::max(s._1, s._2); }
inline int_t min(size4 const & s) { return std::min(s._1, s._2); } inline int_t min(size4 const & s) { return std::min(s._1, s._2); }
struct repeat_infos
{
int_t sub1;
int_t sub2;
int_t rep1;
int_t rep2;
};
enum numeric_type enum numeric_type
{ {
INVALID_NUMERIC_TYPE = 0, INVALID_NUMERIC_TYPE = 0,

View File

@@ -75,9 +75,9 @@ INSTANTIATE(cl_double);
#undef INSTANTIATE #undef INSTANTIATE
// General // General
array::array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, cl::Context context): array::array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context):
dtype_(dtype), shape_(s1.size, s2.size), start_(s1.start, s2.start), stride_(s1.stride, s2.stride), dtype_(dtype), shape_(s1.size, s2.size), start_(s1.start, s2.start), stride_(s1.stride, s2.stride),
ld_(shape_._1), context_(context), data_(data) ld_(ld), context_(context), data_(data)
{ } { }
array::array(array_expression const & proxy) : array::array(array_expression const & proxy) :
@@ -88,6 +88,15 @@ array::array(array_expression const & proxy) :
*this = proxy; *this = proxy;
} }
array::array(array const & other) :
dtype_(other.dtype()),
shape_(other.shape()), start_(0,0), stride_(1, 1), ld_(shape_._1),
context_(other.context()), data_(context_, CL_MEM_READ_WRITE, size_of(dtype_)*dsize())
{
*this = other;
}
/*--- Getters ---*/ /*--- Getters ---*/
numeric_type array::dtype() const numeric_type array::dtype() const
{ return dtype_; } { return dtype_; }
@@ -116,15 +125,6 @@ cl::Buffer const & array::data() const
int_t array::dsize() const int_t array::dsize() const
{ return ld_*shape_._2; } { return ld_*shape_._2; }
/*--- Setters ---*/
array& array::reshape(int_t size1, int_t size2)
{
assert(size1*size2==prod(shape_));
shape_ = size4(size1, size2);
ld_ = size1;
return *this;
}
/*--- Assignment Operators ----*/ /*--- Assignment Operators ----*/
//--------------------------------------- //---------------------------------------
array & array::operator=(array const & rhs) array & array::operator=(array const & rhs)
@@ -215,7 +215,7 @@ void copy(cl::Context & ctx, cl::Buffer const & data, T value)
} }
scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,1), context) scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,1), 1, context)
{ } { }
scalar::scalar(value_scalar value, cl::Context context) : array(1, value.dtype(), context) scalar::scalar(value_scalar value, cl::Context context) : array(1, value.dtype(), context)
@@ -457,18 +457,22 @@ array_expression trans(array_expression const & x) \
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2) array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
{ {
static repeat_infos infos(A.shape(), size4(rep1, rep2)); repeat_infos infos;
infos = repeat_infos(A.shape(), size4(rep1, rep2)); infos.rep1 = rep1;
size4 newshape = prod(infos.sub, infos.rep); infos.rep2 = rep2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), newshape); infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
} }
array_expression repmat(array_expression const & A, int_t const & rep1, int_t const & rep2) array_expression repmat(array_expression const & A, int_t const & rep1, int_t const & rep2)
{ {
static repeat_infos infos(A.shape(), size4(rep1, rep2)); repeat_infos infos;
infos = repeat_infos(A.shape(), size4(rep1, rep2)); infos.rep1 = rep1;
size4 newshape = prod(infos.sub, infos.rep); infos.rep2 = rep2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), newshape); infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
} }
////--------------------------------------- ////---------------------------------------
@@ -575,13 +579,12 @@ namespace detail
return res; return res;
} }
template<class T> template<class T>
array_expression matvecprod(array const & A, T const & x) array_expression matvecprod(array const & A, T const & x)
{ {
int_t M = A.shape()._1; int_t M = A.shape()._1;
int_t N = A.shape()._2; int_t N = A.shape()._2;
return sum(A*repmat(const_cast<T&>(x).reshape(1, N), M, 1), 0); return sum(A*repmat(reshape(x, 1, N), M, 1), 0);
} }
template<class T> template<class T>
@@ -593,13 +596,13 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans) if(A_trans)
{ {
array_expression tmp(A, repmat(const_cast<T&>(x), 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M)); array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M));
//Remove trans //Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs; tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 1); return sum(tmp, 1);
} }
else else
return sum(A*repmat(const_cast<T&>(x).reshape(1, N), M, 1), 0); return sum(A*repmat(reshape(x, 1, N), M, 1), 0);
} }
@@ -611,6 +614,15 @@ namespace detail
} }
array reshape(array const & a, int_t size1, int_t size2)
{
array tmp(a);
tmp.shape_._1 = size1;
tmp.shape_._2 = size2;
return tmp;
}
#define DEFINE_DOT(LTYPE, RTYPE) \ #define DEFINE_DOT(LTYPE, RTYPE) \
array_expression dot(LTYPE const & x, RTYPE const & y)\ array_expression dot(LTYPE const & x, RTYPE const & y)\
{\ {\

View File

@@ -9,20 +9,20 @@ bind_to_handle::bind_to_handle() : current_arg_(0)
{ } { }
// //
bool bind_to_handle::bind(cl::Buffer const * ph) bool bind_to_handle::bind(cl_mem ph)
{ return (ph==NULL)?true:memory.insert(std::make_pair((void*)ph, current_arg_)).second; } { return (ph==NULL)?true:memory.insert(std::make_pair((void*)ph, current_arg_)).second; }
unsigned int bind_to_handle::get(cl::Buffer const * ph) unsigned int bind_to_handle::get(cl_mem ph)
{ return bind(ph)?current_arg_++:memory.at((void*)ph); } { return bind(ph)?current_arg_++:memory.at(ph); }
// //
bind_all_unique::bind_all_unique() : current_arg_(0) bind_all_unique::bind_all_unique() : current_arg_(0)
{ } { }
bool bind_all_unique::bind(cl::Buffer const *) bool bind_all_unique::bind(cl_mem)
{return true;} {return true;}
unsigned int bind_all_unique::get(cl::Buffer const *) unsigned int bind_all_unique::get(cl_mem)
{ return current_arg_++;} { return current_arg_++;}
} }

View File

@@ -419,7 +419,7 @@ void symbolic_expression_representation_functor::append_id(char * & ptr, unsigne
} }
} }
void symbolic_expression_representation_functor::append(cl::Buffer const * h, numeric_type dtype, char prefix) const void symbolic_expression_representation_functor::append(cl_mem h, numeric_type dtype, char prefix) const
{ {
*ptr_++=prefix; *ptr_++=prefix;
*ptr_++=(char)dtype; *ptr_++=(char)dtype;
@@ -429,7 +429,7 @@ void symbolic_expression_representation_functor::append(cl::Buffer const * h, nu
void symbolic_expression_representation_functor::append(lhs_rhs_element const & lhs_rhs) const void symbolic_expression_representation_functor::append(lhs_rhs_element const & lhs_rhs) const
{ {
if(lhs_rhs.subtype==DENSE_ARRAY_TYPE) if(lhs_rhs.subtype==DENSE_ARRAY_TYPE)
append(&lhs_rhs.array->data(), lhs_rhs.array->dtype(), (char)(((int)'0')+lhs_rhs.array->nshape())); append(lhs_rhs.array.data, lhs_rhs.array.dtype, (char)(((int)'0')+((int)(lhs_rhs.array.shape1>1) + (int)(lhs_rhs.array.shape2>1))));
} }
symbolic_expression_representation_functor::symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ } symbolic_expression_representation_functor::symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }

View File

@@ -42,19 +42,19 @@ tools::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, v
} }
/** @brief Vector mapping */ /** @brief Vector mapping */
tools::shared_ptr<mapped_object> base::map_functor::create(array const & a) const tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a) const
{ {
std::string dtype = numeric_type_to_string(a.dtype()); std::string dtype = numeric_type_to_string(a.dtype);
unsigned int id = binder_.get(&a.data()); unsigned int id = binder_.get(a.data);
if(max(a.shape())==1) if(a.shape1==1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id)); return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
else else
{ {
//Column vector //Column vector
if(a.shape()._1>1 && a.shape()._2==1) if(a.shape1>1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c')); return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector //Row vector
else if(a.shape()._1==1 && a.shape()._2>1) else if(a.shape1==1 && a.shape2>1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r')); return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
//Matrix //Matrix
else else
@@ -72,9 +72,9 @@ tools::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const
{ {
switch(lhs_rhs.type_family) switch(lhs_rhs.type_family)
{ {
case INFOS_TYPE_FAMILY: return create(*lhs_rhs.tuple); case INFOS_TYPE_FAMILY: return create(lhs_rhs.tuple);
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar); case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return create(*lhs_rhs.array); case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array);
default: throw ""; default: throw "";
} }
} }
@@ -136,32 +136,32 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
} }
/** @brief Vector mapping */ /** @brief Vector mapping */
void base::set_arguments_functor::set_arguments(array const & x) const void base::set_arguments_functor::set_arguments(array_infos const & x) const
{ {
bool is_bound = binder_.bind(&x.data()); bool is_bound = binder_.bind(x.data);
if (is_bound) if (is_bound)
{ {
//scalar //scalar
if(x.nshape()==0) if(x.shape1==1 && x.shape2==1)
{ {
kernel_.setArg(current_arg_++, x.data()); kernel_.setArg(current_arg_++, x.data);
} }
//array //array
else else
{ {
kernel_.setArg(current_arg_++, x.data()); kernel_.setArg(current_arg_++, x.data);
if(x.nshape()==1) if(x.shape1==1 || x.shape2==1)
{ {
kernel_.setArg(current_arg_++, cl_uint(max(x.start()))); kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
kernel_.setArg(current_arg_++, cl_uint(max(x.stride()))); kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
} }
else else
{ {
kernel_.setArg(current_arg_++, cl_uint(x.ld())); kernel_.setArg(current_arg_++, cl_uint(x.ld));
kernel_.setArg(current_arg_++, cl_uint(x.start()._1)); kernel_.setArg(current_arg_++, cl_uint(x.start1));
kernel_.setArg(current_arg_++, cl_uint(x.start()._2)); kernel_.setArg(current_arg_++, cl_uint(x.start2));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1)); kernel_.setArg(current_arg_++, cl_uint(x.stride1));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2)); kernel_.setArg(current_arg_++, cl_uint(x.stride2));
} }
} }
} }
@@ -169,10 +169,10 @@ void base::set_arguments_functor::set_arguments(array const & x) const
void base::set_arguments_functor::set_arguments(repeat_infos const & i) const void base::set_arguments_functor::set_arguments(repeat_infos const & i) const
{ {
kernel_.setArg(current_arg_++, cl_uint(i.sub._1)); kernel_.setArg(current_arg_++, cl_uint(i.sub1));
kernel_.setArg(current_arg_++, cl_uint(i.sub._2)); kernel_.setArg(current_arg_++, cl_uint(i.sub2));
kernel_.setArg(current_arg_++, cl_uint(i.rep._1)); kernel_.setArg(current_arg_++, cl_uint(i.rep1));
kernel_.setArg(current_arg_++, cl_uint(i.rep._2)); kernel_.setArg(current_arg_++, cl_uint(i.rep2));
} }
void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs) const void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs) const
@@ -180,8 +180,8 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
switch(lhs_rhs.type_family) switch(lhs_rhs.type_family)
{ {
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar); case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(*lhs_rhs.array); case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
case INFOS_TYPE_FAMILY: return set_arguments(*lhs_rhs.tuple); case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
default: throw "oh noez"; default: throw "oh noez";
} }
} }
@@ -376,7 +376,7 @@ bool base::has_strided_access(symbolic_expressions_container const & symbolic_ex
{ {
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it); std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt) for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
if(max(itt->array->stride())>1) if(std::max(itt->array.stride1, itt->array.stride2)>1)
return true; return true;
if(filter_nodes(&is_strided, **it, true).empty()==false) if(filter_nodes(&is_strided, **it, true).empty()==false)
return true; return true;
@@ -388,13 +388,13 @@ int_t base::vector_size(symbolic_expression_node const & node)
{ {
using namespace tools; using namespace tools;
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE) if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()._1, node.lhs.array->shape()._2); return std::min<int_t>(node.lhs.array.shape1, node.lhs.array.shape2);
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE) else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE)
return node.lhs.array->shape()._2; return node.lhs.array.shape2;
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE) else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
return node.lhs.array->shape()._1; return node.lhs.array.shape1;
else else
return max(node.lhs.array->shape()); return std::max(node.lhs.array.shape1, node.lhs.array.shape2);
} }
@@ -402,13 +402,13 @@ std::pair<int_t, int_t> base::matrix_size(symbolic_expression_node const & node)
{ {
if (node.op.type==OPERATOR_VDIAG_TYPE) if (node.op.type==OPERATOR_VDIAG_TYPE)
{ {
int_t size = node.lhs.array->shape()._1; int_t size = node.lhs.array.shape1;
return std::make_pair(size,size); return std::make_pair(size,size);
} }
else if(node.op.type==OPERATOR_REPEAT_TYPE) else if(node.op.type==OPERATOR_REPEAT_TYPE)
return std::make_pair(node.lhs.array->shape()._1*node.rhs.tuple->rep._1, node.lhs.array->shape()._2*node.rhs.tuple->rep._2); return std::make_pair(node.lhs.array.shape1*node.rhs.tuple.rep1, node.lhs.array.shape2*node.rhs.tuple.rep2);
else else
return std::make_pair(node.lhs.array->shape()._1,node.lhs.array->shape()._2); return std::make_pair(node.lhs.array.shape1,node.lhs.array.shape2);
} }
void base::element_wise_loop_1D(kernel_generation_stream & stream, loop_body_base const & loop_body, void base::element_wise_loop_1D(kernel_generation_stream & stream, loop_body_base const & loop_body,
@@ -482,7 +482,7 @@ unsigned int base::align(unsigned int to_round, unsigned int base)
return (to_round + base - 1)/base * base; return (to_round + base - 1)/base * base;
} }
inline tools::shared_ptr<symbolic_binder> base::make_binder() tools::shared_ptr<symbolic_binder> base::make_binder()
{ {
if (binding_policy_==BIND_TO_HANDLE) if (binding_policy_==BIND_TO_HANDLE)
return tools::shared_ptr<symbolic_binder>(new bind_to_handle()); return tools::shared_ptr<symbolic_binder>(new bind_to_handle());
@@ -531,7 +531,7 @@ bool base_impl<TType, PType>::has_misaligned_offset(symbolic_expressions_contain
{ {
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it); std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt) for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
if (max(itt->array->start())>0) if (itt->array.start1>0 || itt->array.start2>0)
return true; return true;
} }
return false; return false;

View File

@@ -35,6 +35,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream.inc_tab(); stream.inc_tab();
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;") process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
("array1", "#pointer += #start;")
("array2", "#pointer = &$VALUE{#start1, #start2};"), symbolic_expressions, mappings); ("array2", "#pointer = &$VALUE{#start1, #start2};"), symbolic_expressions, mappings);
fetching_loop_info(p_.fetching_policy, "M", stream, init0, upper_bound0, inc0, "get_global_id(0)", "get_global_size(0)"); fetching_loop_info(p_.fetching_policy, "M", stream, init0, upper_bound0, inc0, "get_global_id(0)", "get_global_size(0)");

View File

@@ -562,11 +562,11 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
} }
void mproduct::enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K, void mproduct::enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C, array_infos const & A, array_infos const & B, array_infos const & C,
value_scalar const & alpha, value_scalar const & beta, value_scalar const & alpha, value_scalar const & beta,
std::vector<cl::lazy_compiler> & programs, unsigned int label, int id) std::vector<cl::lazy_compiler> & programs, unsigned int label, int id)
{ {
if (min(A.shape())==0 || min(B.shape())==0 || min(C.shape())==0) if (A.shape1==0 || A.shape2==0 || B.shape1==0 || B.shape2==0 || C.shape1==0 || C.shape2==0)
return; return;
char kname[10]; char kname[10];
@@ -594,13 +594,21 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
queue.enqueueNDRangeKernel(kernel, cl::NullRange, grange, lrange); queue.enqueueNDRangeKernel(kernel, cl::NullRange, grange, lrange);
} }
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap) array_infos mproduct::create_slice(array_infos & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{ {
slice s0(s0_0, s0_1); slice s1(s0_0, s0_1);
slice s1(s1_0, s1_1); slice s2(s1_0, s1_1);
if (swap) if (swap)
std::swap(s0, s1); std::swap(s1, s2);
return array(M, s0, s1);
array_infos result = M;
result.shape1 = s1.size;
result.shape2 = s2.size;
result.start1 = M.start1 + M.stride1*s1.start;
result.start2 = M.start2 + M.stride2*s2.start;
result.stride1 = s1.stride*M.stride1;
result.stride2 = s2.stride*M.stride2;
return result;
} }
std::vector<int_t> mproduct::infos(symbolic_expressions_container const & symbolic_expressions, std::vector<int_t> mproduct::infos(symbolic_expressions_container const & symbolic_expressions,
@@ -618,14 +626,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t B_idx = array[root].rhs.node_index; int_t B_idx = array[root].rhs.node_index;
B = array[B_idx].rhs; B = array[B_idx].rhs;
int_t M = C.array->shape()._1; int_t M = C.array.shape1;
int_t N = C.array->shape()._2; int_t N = C.array.shape2;
int_t K = (A_trans_=='T')?A.array->shape()._1:A.array->shape()._2; int_t K = (A_trans_=='T')?A.array.shape1:A.array.shape2;
return tools::make_vector<int_t>() << M << N << K; return tools::make_vector<int_t>() << M << N << K;
} }
mproduct::mproduct(mproduct_parameters const & parameters, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans){ } mproduct::mproduct(mproduct_parameters const & parameters, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans)
{ }
std::vector<int_t> mproduct::input_sizes(symbolic_expressions_container const & symbolic_expressions) std::vector<int_t> mproduct::input_sizes(symbolic_expressions_container const & symbolic_expressions)
{ {
@@ -633,10 +642,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return infos(symbolic_expressions, d0, d1, d2); return infos(symbolic_expressions, d0, d1, d2);
} }
void mproduct::enqueue(cl::CommandQueue & queue, void mproduct::enqueue(cl::CommandQueue & queue, std::vector<cl::lazy_compiler> & programs, unsigned int label, symbolic_expressions_container const & symbolic_expressions)
std::vector<cl::lazy_compiler> & programs,
unsigned int label,
symbolic_expressions_container const & symbolic_expressions)
{ {
using namespace tools; using namespace tools;
@@ -648,15 +654,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t K = MNK[2]; int_t K = MNK[2];
array* pA = A.array; array_infos& pA = A.array;
array* pB = B.array; array_infos& pB = B.array;
array* pC = C.array; array_infos& pC = C.array;
int_t ldstrideA = pA->stride()._1; int_t ldstrideA = pA.stride1;
int_t ldstrideB = pB->stride()._1; int_t ldstrideB = pB.stride1;
int_t ldstrideC = pC->stride()._1; int_t ldstrideC = pC.stride1;
int_t ldstartA = pA->start()._1; int_t ldstartA = pA.start1;
int_t ldstartB = pB->start()._1; int_t ldstartB = pB.start1;
bool swap_A = (A_trans_=='T'); bool swap_A = (A_trans_=='T');
bool swap_B = (B_trans_=='T'); bool swap_B = (B_trans_=='T');
@@ -669,13 +675,12 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar _1d(cl_double(1)); value_scalar _1d(cl_double(1));
value_scalar* _1 = C.dtype==FLOAT_TYPE?&_1f:&_1d; value_scalar* _1 = C.dtype==FLOAT_TYPE?&_1f:&_1d;
if (M < p_.mL || N < p_.nL || K < p_.kL || if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1
ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 || || (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
(p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
{ {
enqueue_block(queue, M, N, K, create_slice(*pA, 0, M, 0, K, swap_A), enqueue_block(queue, M, N, K, create_slice(pA, 0, M, 0, K, swap_A),
create_slice(*pB, 0, K, 0, N, swap_B), create_slice(pB, 0, K, 0, N, swap_B),
create_slice(*pC, 0, M, 0, N, false), *_1, *_0, programs, label, 1); create_slice(pC, 0, M, 0, N, false), *_1, *_0, programs, label, 1);
return; return;
} }
@@ -683,17 +688,17 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t lN = N / p_.nL * p_.nL; int_t lN = N / p_.nL * p_.nL;
int_t lK = K / p_.kL * p_.kL; int_t lK = K / p_.kL * p_.kL;
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), *_1, *_0, programs, label, 0); enqueue_block(queue, lM, lN, lK, create_slice(pA, 0, lM, 0, lK, swap_A), create_slice(pB, 0, lK, 0, lN, swap_B), create_slice(pC, 0, lM, 0, lN, false), *_1, *_0, programs, label, 0);
enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), *_1, *_1, programs, label, 1); enqueue_block(queue, lM, lN, K - lK, create_slice(pA, 0, lM, lK, K, swap_A), create_slice(pB, lK, K, 0, lN, swap_B), create_slice(pC, 0, lM, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, lM, N - lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), *_1, *_0, programs, label, 1); enqueue_block(queue, lM, N - lN, lK, create_slice(pA, 0, lM, 0, lK, swap_A), create_slice(pB, 0, lK, lN, N, swap_B), create_slice(pC, 0, lM, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, lM, N - lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), *_1, *_1, programs, label, 1); enqueue_block(queue, lM, N - lN, K - lK, create_slice(pA, 0, lM, lK, K, swap_A), create_slice(pB, lK, K, lN, N, swap_B), create_slice(pC, 0, lM, lN, N, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), *_1, *_0, programs, label, 1); enqueue_block(queue, M - lM, lN, lK, create_slice(pA, lM, M, 0, lK, swap_A), create_slice(pB, 0, lK, 0, lN, swap_B), create_slice(pC, lM, M, 0, lN, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), *_1, *_1, programs, label, 1); enqueue_block(queue, M - lM, lN, K - lK, create_slice(pA, lM, M, lK, K, swap_A), create_slice(pB, lK, K, 0, lN, swap_B), create_slice(pC, lM, M, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), *_1, *_0, programs, label, 1); enqueue_block(queue, M - lM, N - lN, lK, create_slice(pA, lM, M, 0, lK, swap_A), create_slice(pB, 0, lK, lN, N, swap_B), create_slice(pC, lM, M, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), *_1, *_1, programs, label, 1); enqueue_block(queue, M - lM, N - lN, K - lK, create_slice(pA, lM, M, lK, K, swap_A), create_slice(pB, lK, K, lN, N, swap_B), create_slice(pC, lM, M, lN, N, false), *_1, *_1, programs, label, 1);
} }
// //

View File

@@ -161,9 +161,9 @@ namespace atidlas
//Init //Init
expression_type current_type; expression_type current_type;
if(root_save.lhs.array->nshape()==0) if(root_save.lhs.array.shape1==1 && root_save.lhs.array.shape2==1)
current_type = SCALAR_AXPY_TYPE; current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1) else if(root_save.lhs.array.shape1==1 || root_save.lhs.array.shape2==1)
current_type=VECTOR_AXPY_TYPE; current_type=VECTOR_AXPY_TYPE;
else else
current_type=MATRIX_AXPY_TYPE; current_type=MATRIX_AXPY_TYPE;
@@ -186,15 +186,15 @@ namespace atidlas
case SCALAR_AXPY_TYPE: case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(1, dtype, context)); break; case REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(1, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break; case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break; case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._2, dtype, context)); break; case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, lmost.lhs.array->shape()._2, dtype, context)); break; case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._2, dtype, context)); break; case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._1, dtype, context)); break; case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape1, dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._2, dtype, context)); break; case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._1, dtype, context)); break; case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape1, dtype, context)); break;
default: throw "This shouldn't happen. Ever."; default: throw "This shouldn't happen. Ever.";
} }
@@ -213,7 +213,7 @@ namespace atidlas
rit->second->dtype = dtype; rit->second->dtype = dtype;
rit->second->type_family = ARRAY_TYPE_FAMILY; rit->second->type_family = ARRAY_TYPE_FAMILY;
rit->second->subtype = DENSE_ARRAY_TYPE; rit->second->subtype = DENSE_ARRAY_TYPE;
rit->second->array = (array*)tmp.get(); fill((array&)*tmp, rit->second->array);
} }
/*-----Compute final expression-----*/ /*-----Compute final expression-----*/

View File

@@ -8,6 +8,19 @@
namespace atidlas namespace atidlas
{ {
void fill(array const & a, array_infos& i)
{
i.dtype = a.dtype();
i.data = a.data()();
i.shape1 = a.shape()._1;
i.shape2 = a.shape()._2;
i.start1 = a.start()._1;
i.start2 = a.start()._2;
i.stride1 = a.stride()._1;
i.stride2 = a.stride()._2;
i.ld = a.ld();
}
lhs_rhs_element::lhs_rhs_element() lhs_rhs_element::lhs_rhs_element()
{ {
type_family = INVALID_TYPE_FAMILY; type_family = INVALID_TYPE_FAMILY;
@@ -28,7 +41,8 @@ lhs_rhs_element::lhs_rhs_element(atidlas::array const & x)
type_family = ARRAY_TYPE_FAMILY; type_family = ARRAY_TYPE_FAMILY;
subtype = DENSE_ARRAY_TYPE; subtype = DENSE_ARRAY_TYPE;
dtype = x.dtype(); dtype = x.dtype();
array = (atidlas::array*)&x; fill(x, array);
memory_ = x.data();
} }
lhs_rhs_element::lhs_rhs_element(atidlas::value_scalar const & x) lhs_rhs_element::lhs_rhs_element(atidlas::value_scalar const & x)
@@ -44,7 +58,7 @@ lhs_rhs_element::lhs_rhs_element(atidlas::repeat_infos const & x)
type_family = INFOS_TYPE_FAMILY; type_family = INFOS_TYPE_FAMILY;
subtype = REPEAT_INFOS_TYPE; subtype = REPEAT_INFOS_TYPE;
dtype = INVALID_NUMERIC_TYPE; dtype = INVALID_NUMERIC_TYPE;
tuple = (atidlas::repeat_infos*)&x; tuple = x;
} }
// //

View File

@@ -43,46 +43,46 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
std::cout << std::endl;\ std::cout << std::endl;\
} }
// RUN_TEST("C = A", cC(i,j) = cA(i,j), C = A) RUN_TEST("C = A", cC(i,j) = cA(i,j), C = A)
// RUN_TEST("C = A + B", cC(i,j) = cA(i,j) + cB(i,j), C = A + B) RUN_TEST("C = A + B", cC(i,j) = cA(i,j) + cB(i,j), C = A + B)
// RUN_TEST("C = A - B", cC(i,j) = cA(i,j) - cB(i,j), C = A - B) RUN_TEST("C = A - B", cC(i,j) = cA(i,j) - cB(i,j), C = A - B)
// RUN_TEST("C = A + B + C", cC(i,j) = cA(i,j) + cB(i,j) + cC(i,j), C = A + B + C) RUN_TEST("C = A + B + C", cC(i,j) = cA(i,j) + cB(i,j) + cC(i,j), C = A + B + C)
// RUN_TEST("C = a*A", cC(i,j) = aa*cA(i,j), C = a*A) RUN_TEST("C = a*A", cC(i,j) = aa*cA(i,j), C = a*A)
// RUN_TEST("C = da*A", cC(i,j) = aa*cA(i,j), C = da*A) RUN_TEST("C = da*A", cC(i,j) = aa*cA(i,j), C = da*A)
// RUN_TEST("C = a*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + b*B) RUN_TEST("C = a*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + b*B)
// RUN_TEST("C = da*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + b*B) RUN_TEST("C = da*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + b*B)
// RUN_TEST("C = a*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + db*B) RUN_TEST("C = a*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + db*B)
// RUN_TEST("C = da*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + db*B) RUN_TEST("C = da*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + db*B)
// RUN_TEST("C = exp(A)", cC(i,j) = exp(cA(i,j)), C= exp(A)) RUN_TEST("C = exp(A)", cC(i,j) = exp(cA(i,j)), C= exp(A))
// RUN_TEST("C = abs(A)", cC(i,j) = abs(cA(i,j)), C= abs(A)) RUN_TEST("C = abs(A)", cC(i,j) = abs(cA(i,j)), C= abs(A))
// RUN_TEST("C = acos(A)", cC(i,j) = acos(cA(i,j)), C= acos(A)) RUN_TEST("C = acos(A)", cC(i,j) = acos(cA(i,j)), C= acos(A))
// RUN_TEST("C = asin(A)", cC(i,j) = asin(cA(i,j)), C= asin(A)) RUN_TEST("C = asin(A)", cC(i,j) = asin(cA(i,j)), C= asin(A))
// RUN_TEST("C = atan(A)", cC(i,j) = atan(cA(i,j)), C= atan(A)) RUN_TEST("C = atan(A)", cC(i,j) = atan(cA(i,j)), C= atan(A))
// RUN_TEST("C = ceil(A)", cC(i,j) = ceil(cA(i,j)), C= ceil(A)) RUN_TEST("C = ceil(A)", cC(i,j) = ceil(cA(i,j)), C= ceil(A))
// RUN_TEST("C = cos(A)", cC(i,j) = cos(cA(i,j)), C= cos(A)) RUN_TEST("C = cos(A)", cC(i,j) = cos(cA(i,j)), C= cos(A))
// RUN_TEST("C = cosh(A)", cC(i,j) = cosh(cA(i,j)), C= cosh(A)) RUN_TEST("C = cosh(A)", cC(i,j) = cosh(cA(i,j)), C= cosh(A))
// RUN_TEST("C = floor(A)", cC(i,j) = floor(cA(i,j)), C= floor(A)) RUN_TEST("C = floor(A)", cC(i,j) = floor(cA(i,j)), C= floor(A))
// RUN_TEST("C = log(A)", cC(i,j) = log(cA(i,j)), C= log(A)) RUN_TEST("C = log(A)", cC(i,j) = log(cA(i,j)), C= log(A))
// RUN_TEST("C = log10(A)", cC(i,j) = log10(cA(i,j)), C= log10(A)) RUN_TEST("C = log10(A)", cC(i,j) = log10(cA(i,j)), C= log10(A))
// RUN_TEST("C = sin(A)", cC(i,j) = sin(cA(i,j)), C= sin(A)) RUN_TEST("C = sin(A)", cC(i,j) = sin(cA(i,j)), C= sin(A))
// RUN_TEST("C = sinh(A)", cC(i,j) = sinh(cA(i,j)), C= sinh(A)) RUN_TEST("C = sinh(A)", cC(i,j) = sinh(cA(i,j)), C= sinh(A))
// RUN_TEST("C = sqrt(A)", cC(i,j) = sqrt(cA(i,j)), C= sqrt(A)) RUN_TEST("C = sqrt(A)", cC(i,j) = sqrt(cA(i,j)), C= sqrt(A))
// RUN_TEST("C = tan(A)", cC(i,j) = tan(cA(i,j)), C= tan(A)) RUN_TEST("C = tan(A)", cC(i,j) = tan(cA(i,j)), C= tan(A))
// RUN_TEST("C = tanh(A)", cC(i,j) = tanh(cA(i,j)), C= tanh(A)) RUN_TEST("C = tanh(A)", cC(i,j) = tanh(cA(i,j)), C= tanh(A))
// RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B) RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
// RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B) RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
// RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B) RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
// RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B) RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
// RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B) RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
// RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B) RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
// RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B) RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
// RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B) RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
// RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B)) RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
// RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype())) RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype()))
RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y)) RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y))
#undef RUN_TEST #undef RUN_TEST

View File

@@ -57,8 +57,8 @@ void test_impl(T epsilon)
INIT_VECTOR(M, SUBM, 6, 2, cy, y); INIT_VECTOR(M, SUBM, 6, 2, cy, y);
INIT_VECTOR(N, SUBN, 4, 3, cx, x); INIT_VECTOR(N, SUBN, 4, 3, cx, x);
// std::cout << "full..." << std::endl; std::cout << "full..." << std::endl;
// test_row_wise_reduction(epsilon, cy_full, cA_full, cx_full, y_full, A_full, x_full); test_row_wise_reduction(epsilon, cy_full, cA_full, cx_full, y_full, A_full, x_full);
std::cout << "slice..." << std::endl; std::cout << "slice..." << std::endl;
test_row_wise_reduction(epsilon, cy_slice, cA_slice, cx_slice, y_slice, A_slice, x_slice); test_row_wise_reduction(epsilon, cy_slice, cA_slice, cx_slice, y_slice, A_slice, x_slice);
} }