low level representation of array

This commit is contained in:
Philippe Tillet
2015-01-18 14:52:45 -05:00
parent 16648f18e0
commit edaa821d93
17 changed files with 243 additions and 194 deletions

View File

@@ -15,6 +15,7 @@ class scalar;
class array: public obj_base
{
friend array reshape(array const &, int_t, int_t);
public:
//1D Constructors
array(int_t size1, numeric_type dtype, cl::Context context = cl::default_context());
@@ -29,9 +30,10 @@ public:
array(array & M, slice const & s1, slice const & s2);
//General constructor
array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, cl::Context context = cl::default_context());
explicit array(array_expression const & proxy);
array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context = cl::default_context());
array(array_expression const & proxy);
array(array const &);
//Getters
numeric_type dtype() const;
size4 shape() const;
@@ -45,7 +47,6 @@ public:
//Setters
array& resize(int_t size1, int_t size2=1);
array& reshape(int_t size1, int_t size2=1);
//Numeric operators
array& operator=(array const &);
@@ -114,8 +115,8 @@ public:
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
array_expression zeros(std::size_t N, numeric_type dtype);
array reshape(array const &, int_t, int_t);
//copy
@@ -197,13 +198,6 @@ array_expression norm(array_expression const &, unsigned int order = 2);
#undef ATIDLAS_DECLARE_UNARY_OPERATOR
struct repeat_infos
{
repeat_infos(size4 const & _sub, size4 const & _rep) : sub(_sub), rep(_rep){ }
size4 sub;
size4 rep;
};
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
#define ATIDLAS_DECLARE_REDUCTION(OPNAME) \

View File

@@ -18,8 +18,8 @@ class symbolic_binder
{
public:
virtual ~symbolic_binder();
virtual bool bind(cl::Buffer const * ph) = 0;
virtual unsigned int get(cl::Buffer const * ph) = 0;
virtual bool bind(cl_mem ph) = 0;
virtual unsigned int get(cl_mem ph) = 0;
};
@@ -27,8 +27,8 @@ class bind_to_handle : public symbolic_binder
{
public:
bind_to_handle();
bool bind(cl::Buffer const * ph);
unsigned int get(cl::Buffer const * ph);
bool bind(cl_mem ph);
unsigned int get(cl_mem ph);
private:
unsigned int current_arg_;
std::map<void*,unsigned int> memory;
@@ -38,8 +38,8 @@ class bind_all_unique : public symbolic_binder
{
public:
bind_all_unique();
bool bind(cl::Buffer const *);
unsigned int get(cl::Buffer const *);
bool bind(cl_mem);
unsigned int get(cl_mem);
private:
unsigned int current_arg_;
std::map<void*,unsigned int> memory;

View File

@@ -139,7 +139,7 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::strin
class symbolic_expression_representation_functor : public traversal_functor{
private:
static void append_id(char * & ptr, unsigned int val);
void append(cl::Buffer const * h, numeric_type dtype, char prefix) const;
void append(cl_mem h, numeric_type dtype, char prefix) const;
void append(lhs_rhs_element const & lhs_rhs) const;
public:
symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr);

View File

@@ -79,7 +79,7 @@ protected:
/** @brief Creates a value scalar mapping */
tools::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
/** @brief Creates a vector mapping */
tools::shared_ptr<mapped_object> create(array const &) const;
tools::shared_ptr<mapped_object> create(array_infos const &) const;
/** @brief Creates a tuple mapping */
tools::shared_ptr<mapped_object> create(repeat_infos const &) const;
/** @brief Creates a mapping */
@@ -101,7 +101,7 @@ protected:
set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, cl::Kernel & kernel);
void set_arguments(numeric_type dtype, values_holder const & scal) const;
void set_arguments(array const & ) const;
void set_arguments(array_infos const & ) const;
void set_arguments(repeat_infos const & i) const;
void set_arguments(lhs_rhs_element const & lhs_rhs) const;

View File

@@ -39,10 +39,10 @@ private:
std::string generate_impl(unsigned int label, char id, const symbolic_expressions_container &symbolic_expressions, const std::vector<mapping_type> &, bool fallback) const;
std::vector<std::string> generate_impl(unsigned int label, symbolic_expressions_container const & symbolic_expressions, std::vector<mapping_type> const & mappings) const;
void enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C,
array_infos const & A, array_infos const & B, array_infos const & C,
value_scalar const & alpha, value_scalar const & beta,
std::vector<cl::lazy_compiler> & programs, unsigned int label, int id);
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
array_infos create_slice(array_infos & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(symbolic_expressions_container const & symbolic_expressions,
lhs_rhs_element & C, lhs_rhs_element & A, lhs_rhs_element & B);
public:

View File

@@ -133,6 +133,20 @@ enum symbolic_expression_node_subtype
REPEAT_INFOS_TYPE
};
struct array_infos
{
numeric_type dtype;
cl_mem data;
int_t shape1;
int_t shape2;
int_t start1;
int_t start2;
int_t stride1;
int_t stride2;
int_t ld;
};
void fill(array const & a, array_infos& i);
struct lhs_rhs_element
{
lhs_rhs_element();
@@ -144,14 +158,14 @@ struct lhs_rhs_element
symbolic_expression_node_type_family type_family;
symbolic_expression_node_subtype subtype;
numeric_type dtype;
union
{
unsigned int node_index;
atidlas::array * array;
values_holder vscalar;
atidlas::repeat_infos * tuple;
repeat_infos tuple;
array_infos array;
};
cl::Buffer memory_;
};

View File

@@ -6,19 +6,28 @@
namespace atidlas
{
typedef int int_t;
typedef int int_t;
struct size4
{
struct size4
{
size4(int_t s1, int_t s2 = 1) : _1(s1), _2(s2){ }
int_t prod() const { return _1*_2; }
bool operator==(size4 const & other) const { return _1==other._1 && _2==other._2; }
int_t _1;
int_t _2;
};
inline int_t prod(size4 const & s) { return s._1*s._2; }
inline int_t max(size4 const & s) { return std::max(s._1, s._2); }
inline int_t min(size4 const & s) { return std::min(s._1, s._2); }
};
inline int_t prod(size4 const & s) { return s._1*s._2; }
inline int_t max(size4 const & s) { return std::max(s._1, s._2); }
inline int_t min(size4 const & s) { return std::min(s._1, s._2); }
struct repeat_infos
{
int_t sub1;
int_t sub2;
int_t rep1;
int_t rep2;
};
enum numeric_type
{

View File

@@ -75,9 +75,9 @@ INSTANTIATE(cl_double);
#undef INSTANTIATE
// General
array::array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, cl::Context context):
array::array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context):
dtype_(dtype), shape_(s1.size, s2.size), start_(s1.start, s2.start), stride_(s1.stride, s2.stride),
ld_(shape_._1), context_(context), data_(data)
ld_(ld), context_(context), data_(data)
{ }
array::array(array_expression const & proxy) :
@@ -88,6 +88,15 @@ array::array(array_expression const & proxy) :
*this = proxy;
}
array::array(array const & other) :
dtype_(other.dtype()),
shape_(other.shape()), start_(0,0), stride_(1, 1), ld_(shape_._1),
context_(other.context()), data_(context_, CL_MEM_READ_WRITE, size_of(dtype_)*dsize())
{
*this = other;
}
/*--- Getters ---*/
numeric_type array::dtype() const
{ return dtype_; }
@@ -116,15 +125,6 @@ cl::Buffer const & array::data() const
int_t array::dsize() const
{ return ld_*shape_._2; }
/*--- Setters ---*/
array& array::reshape(int_t size1, int_t size2)
{
assert(size1*size2==prod(shape_));
shape_ = size4(size1, size2);
ld_ = size1;
return *this;
}
/*--- Assignment Operators ----*/
//---------------------------------------
array & array::operator=(array const & rhs)
@@ -215,7 +215,7 @@ void copy(cl::Context & ctx, cl::Buffer const & data, T value)
}
scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,1), context)
scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,1), 1, context)
{ }
scalar::scalar(value_scalar value, cl::Context context) : array(1, value.dtype(), context)
@@ -457,18 +457,22 @@ array_expression trans(array_expression const & x) \
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
{
static repeat_infos infos(A.shape(), size4(rep1, rep2));
infos = repeat_infos(A.shape(), size4(rep1, rep2));
size4 newshape = prod(infos.sub, infos.rep);
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), newshape);
repeat_infos infos;
infos.rep1 = rep1;
infos.rep2 = rep2;
infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
}
array_expression repmat(array_expression const & A, int_t const & rep1, int_t const & rep2)
{
static repeat_infos infos(A.shape(), size4(rep1, rep2));
infos = repeat_infos(A.shape(), size4(rep1, rep2));
size4 newshape = prod(infos.sub, infos.rep);
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), newshape);
repeat_infos infos;
infos.rep1 = rep1;
infos.rep2 = rep2;
infos.sub1 = A.shape()._1;
infos.sub2 = A.shape()._2;
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
}
////---------------------------------------
@@ -575,13 +579,12 @@ namespace detail
return res;
}
template<class T>
array_expression matvecprod(array const & A, T const & x)
{
int_t M = A.shape()._1;
int_t N = A.shape()._2;
return sum(A*repmat(const_cast<T&>(x).reshape(1, N), M, 1), 0);
return sum(A*repmat(reshape(x, 1, N), M, 1), 0);
}
template<class T>
@@ -593,13 +596,13 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans)
{
array_expression tmp(A, repmat(const_cast<T&>(x), 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M));
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), size4(N, M));
//Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 1);
}
else
return sum(A*repmat(const_cast<T&>(x).reshape(1, N), M, 1), 0);
return sum(A*repmat(reshape(x, 1, N), M, 1), 0);
}
@@ -611,6 +614,15 @@ namespace detail
}
array reshape(array const & a, int_t size1, int_t size2)
{
array tmp(a);
tmp.shape_._1 = size1;
tmp.shape_._2 = size2;
return tmp;
}
#define DEFINE_DOT(LTYPE, RTYPE) \
array_expression dot(LTYPE const & x, RTYPE const & y)\
{\

View File

@@ -9,20 +9,20 @@ bind_to_handle::bind_to_handle() : current_arg_(0)
{ }
//
bool bind_to_handle::bind(cl::Buffer const * ph)
bool bind_to_handle::bind(cl_mem ph)
{ return (ph==NULL)?true:memory.insert(std::make_pair((void*)ph, current_arg_)).second; }
unsigned int bind_to_handle::get(cl::Buffer const * ph)
{ return bind(ph)?current_arg_++:memory.at((void*)ph); }
unsigned int bind_to_handle::get(cl_mem ph)
{ return bind(ph)?current_arg_++:memory.at(ph); }
//
bind_all_unique::bind_all_unique() : current_arg_(0)
{ }
bool bind_all_unique::bind(cl::Buffer const *)
bool bind_all_unique::bind(cl_mem)
{return true;}
unsigned int bind_all_unique::get(cl::Buffer const *)
unsigned int bind_all_unique::get(cl_mem)
{ return current_arg_++;}
}

View File

@@ -419,7 +419,7 @@ void symbolic_expression_representation_functor::append_id(char * & ptr, unsigne
}
}
void symbolic_expression_representation_functor::append(cl::Buffer const * h, numeric_type dtype, char prefix) const
void symbolic_expression_representation_functor::append(cl_mem h, numeric_type dtype, char prefix) const
{
*ptr_++=prefix;
*ptr_++=(char)dtype;
@@ -429,7 +429,7 @@ void symbolic_expression_representation_functor::append(cl::Buffer const * h, nu
void symbolic_expression_representation_functor::append(lhs_rhs_element const & lhs_rhs) const
{
if(lhs_rhs.subtype==DENSE_ARRAY_TYPE)
append(&lhs_rhs.array->data(), lhs_rhs.array->dtype(), (char)(((int)'0')+lhs_rhs.array->nshape()));
append(lhs_rhs.array.data, lhs_rhs.array.dtype, (char)(((int)'0')+((int)(lhs_rhs.array.shape1>1) + (int)(lhs_rhs.array.shape2>1))));
}
symbolic_expression_representation_functor::symbolic_expression_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }

View File

@@ -42,19 +42,19 @@ tools::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, v
}
/** @brief Vector mapping */
tools::shared_ptr<mapped_object> base::map_functor::create(array const & a) const
tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a) const
{
std::string dtype = numeric_type_to_string(a.dtype());
unsigned int id = binder_.get(&a.data());
if(max(a.shape())==1)
std::string dtype = numeric_type_to_string(a.dtype);
unsigned int id = binder_.get(a.data);
if(a.shape1==1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
else
{
//Column vector
if(a.shape()._1>1 && a.shape()._2==1)
if(a.shape1>1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector
else if(a.shape()._1==1 && a.shape()._2>1)
else if(a.shape1==1 && a.shape2>1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
//Matrix
else
@@ -72,9 +72,9 @@ tools::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const
{
switch(lhs_rhs.type_family)
{
case INFOS_TYPE_FAMILY: return create(*lhs_rhs.tuple);
case INFOS_TYPE_FAMILY: return create(lhs_rhs.tuple);
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return create(*lhs_rhs.array);
case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array);
default: throw "";
}
}
@@ -136,32 +136,32 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
}
/** @brief Vector mapping */
void base::set_arguments_functor::set_arguments(array const & x) const
void base::set_arguments_functor::set_arguments(array_infos const & x) const
{
bool is_bound = binder_.bind(&x.data());
bool is_bound = binder_.bind(x.data);
if (is_bound)
{
//scalar
if(x.nshape()==0)
if(x.shape1==1 && x.shape2==1)
{
kernel_.setArg(current_arg_++, x.data());
kernel_.setArg(current_arg_++, x.data);
}
//array
else
{
kernel_.setArg(current_arg_++, x.data());
if(x.nshape()==1)
kernel_.setArg(current_arg_++, x.data);
if(x.shape1==1 || x.shape2==1)
{
kernel_.setArg(current_arg_++, cl_uint(max(x.start())));
kernel_.setArg(current_arg_++, cl_uint(max(x.stride())));
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
}
else
{
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
kernel_.setArg(current_arg_++, cl_uint(x.ld));
kernel_.setArg(current_arg_++, cl_uint(x.start1));
kernel_.setArg(current_arg_++, cl_uint(x.start2));
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
}
}
}
@@ -169,10 +169,10 @@ void base::set_arguments_functor::set_arguments(array const & x) const
void base::set_arguments_functor::set_arguments(repeat_infos const & i) const
{
kernel_.setArg(current_arg_++, cl_uint(i.sub._1));
kernel_.setArg(current_arg_++, cl_uint(i.sub._2));
kernel_.setArg(current_arg_++, cl_uint(i.rep._1));
kernel_.setArg(current_arg_++, cl_uint(i.rep._2));
kernel_.setArg(current_arg_++, cl_uint(i.sub1));
kernel_.setArg(current_arg_++, cl_uint(i.sub2));
kernel_.setArg(current_arg_++, cl_uint(i.rep1));
kernel_.setArg(current_arg_++, cl_uint(i.rep2));
}
void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs) const
@@ -180,8 +180,8 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
switch(lhs_rhs.type_family)
{
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(*lhs_rhs.array);
case INFOS_TYPE_FAMILY: return set_arguments(*lhs_rhs.tuple);
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
default: throw "oh noez";
}
}
@@ -376,7 +376,7 @@ bool base::has_strided_access(symbolic_expressions_container const & symbolic_ex
{
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
if(max(itt->array->stride())>1)
if(std::max(itt->array.stride1, itt->array.stride2)>1)
return true;
if(filter_nodes(&is_strided, **it, true).empty()==false)
return true;
@@ -388,13 +388,13 @@ int_t base::vector_size(symbolic_expression_node const & node)
{
using namespace tools;
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()._1, node.lhs.array->shape()._2);
return std::min<int_t>(node.lhs.array.shape1, node.lhs.array.shape2);
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE)
return node.lhs.array->shape()._2;
return node.lhs.array.shape2;
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
return node.lhs.array->shape()._1;
return node.lhs.array.shape1;
else
return max(node.lhs.array->shape());
return std::max(node.lhs.array.shape1, node.lhs.array.shape2);
}
@@ -402,13 +402,13 @@ std::pair<int_t, int_t> base::matrix_size(symbolic_expression_node const & node)
{
if (node.op.type==OPERATOR_VDIAG_TYPE)
{
int_t size = node.lhs.array->shape()._1;
int_t size = node.lhs.array.shape1;
return std::make_pair(size,size);
}
else if(node.op.type==OPERATOR_REPEAT_TYPE)
return std::make_pair(node.lhs.array->shape()._1*node.rhs.tuple->rep._1, node.lhs.array->shape()._2*node.rhs.tuple->rep._2);
return std::make_pair(node.lhs.array.shape1*node.rhs.tuple.rep1, node.lhs.array.shape2*node.rhs.tuple.rep2);
else
return std::make_pair(node.lhs.array->shape()._1,node.lhs.array->shape()._2);
return std::make_pair(node.lhs.array.shape1,node.lhs.array.shape2);
}
void base::element_wise_loop_1D(kernel_generation_stream & stream, loop_body_base const & loop_body,
@@ -482,7 +482,7 @@ unsigned int base::align(unsigned int to_round, unsigned int base)
return (to_round + base - 1)/base * base;
}
inline tools::shared_ptr<symbolic_binder> base::make_binder()
tools::shared_ptr<symbolic_binder> base::make_binder()
{
if (binding_policy_==BIND_TO_HANDLE)
return tools::shared_ptr<symbolic_binder>(new bind_to_handle());
@@ -531,7 +531,7 @@ bool base_impl<TType, PType>::has_misaligned_offset(symbolic_expressions_contain
{
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
if (max(itt->array->start())>0)
if (itt->array.start1>0 || itt->array.start2>0)
return true;
}
return false;

View File

@@ -35,6 +35,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream.inc_tab();
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
("array1", "#pointer += #start;")
("array2", "#pointer = &$VALUE{#start1, #start2};"), symbolic_expressions, mappings);
fetching_loop_info(p_.fetching_policy, "M", stream, init0, upper_bound0, inc0, "get_global_id(0)", "get_global_size(0)");

View File

@@ -562,11 +562,11 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
}
void mproduct::enqueue_block(cl::CommandQueue & queue, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C,
array_infos const & A, array_infos const & B, array_infos const & C,
value_scalar const & alpha, value_scalar const & beta,
std::vector<cl::lazy_compiler> & programs, unsigned int label, int id)
{
if (min(A.shape())==0 || min(B.shape())==0 || min(C.shape())==0)
if (A.shape1==0 || A.shape2==0 || B.shape1==0 || B.shape2==0 || C.shape1==0 || C.shape2==0)
return;
char kname[10];
@@ -594,13 +594,21 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
queue.enqueueNDRangeKernel(kernel, cl::NullRange, grange, lrange);
}
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
array_infos mproduct::create_slice(array_infos & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{
slice s0(s0_0, s0_1);
slice s1(s1_0, s1_1);
slice s1(s0_0, s0_1);
slice s2(s1_0, s1_1);
if (swap)
std::swap(s0, s1);
return array(M, s0, s1);
std::swap(s1, s2);
array_infos result = M;
result.shape1 = s1.size;
result.shape2 = s2.size;
result.start1 = M.start1 + M.stride1*s1.start;
result.start2 = M.start2 + M.stride2*s2.start;
result.stride1 = s1.stride*M.stride1;
result.stride2 = s2.stride*M.stride2;
return result;
}
std::vector<int_t> mproduct::infos(symbolic_expressions_container const & symbolic_expressions,
@@ -618,14 +626,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t B_idx = array[root].rhs.node_index;
B = array[B_idx].rhs;
int_t M = C.array->shape()._1;
int_t N = C.array->shape()._2;
int_t K = (A_trans_=='T')?A.array->shape()._1:A.array->shape()._2;
int_t M = C.array.shape1;
int_t N = C.array.shape2;
int_t K = (A_trans_=='T')?A.array.shape1:A.array.shape2;
return tools::make_vector<int_t>() << M << N << K;
}
mproduct::mproduct(mproduct_parameters const & parameters, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans){ }
mproduct::mproduct(mproduct_parameters const & parameters, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans)
{ }
std::vector<int_t> mproduct::input_sizes(symbolic_expressions_container const & symbolic_expressions)
{
@@ -633,10 +642,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return infos(symbolic_expressions, d0, d1, d2);
}
void mproduct::enqueue(cl::CommandQueue & queue,
std::vector<cl::lazy_compiler> & programs,
unsigned int label,
symbolic_expressions_container const & symbolic_expressions)
void mproduct::enqueue(cl::CommandQueue & queue, std::vector<cl::lazy_compiler> & programs, unsigned int label, symbolic_expressions_container const & symbolic_expressions)
{
using namespace tools;
@@ -648,15 +654,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t K = MNK[2];
array* pA = A.array;
array* pB = B.array;
array* pC = C.array;
array_infos& pA = A.array;
array_infos& pB = B.array;
array_infos& pC = C.array;
int_t ldstrideA = pA->stride()._1;
int_t ldstrideB = pB->stride()._1;
int_t ldstrideC = pC->stride()._1;
int_t ldstartA = pA->start()._1;
int_t ldstartB = pB->start()._1;
int_t ldstrideA = pA.stride1;
int_t ldstrideB = pB.stride1;
int_t ldstrideC = pC.stride1;
int_t ldstartA = pA.start1;
int_t ldstartB = pB.start1;
bool swap_A = (A_trans_=='T');
bool swap_B = (B_trans_=='T');
@@ -669,13 +675,12 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar _1d(cl_double(1));
value_scalar* _1 = C.dtype==FLOAT_TYPE?&_1f:&_1d;
if (M < p_.mL || N < p_.nL || K < p_.kL ||
ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 ||
(p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1
|| (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
{
enqueue_block(queue, M, N, K, create_slice(*pA, 0, M, 0, K, swap_A),
create_slice(*pB, 0, K, 0, N, swap_B),
create_slice(*pC, 0, M, 0, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M, N, K, create_slice(pA, 0, M, 0, K, swap_A),
create_slice(pB, 0, K, 0, N, swap_B),
create_slice(pC, 0, M, 0, N, false), *_1, *_0, programs, label, 1);
return;
}
@@ -683,17 +688,17 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t lN = N / p_.nL * p_.nL;
int_t lK = K / p_.kL * p_.kL;
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), *_1, *_0, programs, label, 0);
enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, lM, lN, lK, create_slice(pA, 0, lM, 0, lK, swap_A), create_slice(pB, 0, lK, 0, lN, swap_B), create_slice(pC, 0, lM, 0, lN, false), *_1, *_0, programs, label, 0);
enqueue_block(queue, lM, lN, K - lK, create_slice(pA, 0, lM, lK, K, swap_A), create_slice(pB, lK, K, 0, lN, swap_B), create_slice(pC, 0, lM, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, lM, N - lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, lM, N - lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, lM, N - lN, lK, create_slice(pA, 0, lM, 0, lK, swap_A), create_slice(pB, 0, lK, lN, N, swap_B), create_slice(pC, 0, lM, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, lM, N - lN, K - lK, create_slice(pA, 0, lM, lK, K, swap_A), create_slice(pB, lK, K, lN, N, swap_B), create_slice(pC, 0, lM, lN, N, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, lN, lK, create_slice(pA, lM, M, 0, lK, swap_A), create_slice(pB, 0, lK, 0, lN, swap_B), create_slice(pC, lM, M, 0, lN, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, lN, K - lK, create_slice(pA, lM, M, lK, K, swap_A), create_slice(pB, lK, K, 0, lN, swap_B), create_slice(pC, lM, M, 0, lN, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), *_1, *_1, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, lK, create_slice(pA, lM, M, 0, lK, swap_A), create_slice(pB, 0, lK, lN, N, swap_B), create_slice(pC, lM, M, lN, N, false), *_1, *_0, programs, label, 1);
enqueue_block(queue, M - lM, N - lN, K - lK, create_slice(pA, lM, M, lK, K, swap_A), create_slice(pB, lK, K, lN, N, swap_B), create_slice(pC, lM, M, lN, N, false), *_1, *_1, programs, label, 1);
}
//

View File

@@ -161,9 +161,9 @@ namespace atidlas
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
if(root_save.lhs.array.shape1==1 && root_save.lhs.array.shape2==1)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
else if(root_save.lhs.array.shape1==1 || root_save.lhs.array.shape2==1)
current_type=VECTOR_AXPY_TYPE;
else
current_type=MATRIX_AXPY_TYPE;
@@ -186,15 +186,15 @@ namespace atidlas
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(1, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._2, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, lmost.lhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._1, dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._1, dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape1, dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape1, dtype, context)); break;
default: throw "This shouldn't happen. Ever.";
}
@@ -213,7 +213,7 @@ namespace atidlas
rit->second->dtype = dtype;
rit->second->type_family = ARRAY_TYPE_FAMILY;
rit->second->subtype = DENSE_ARRAY_TYPE;
rit->second->array = (array*)tmp.get();
fill((array&)*tmp, rit->second->array);
}
/*-----Compute final expression-----*/

View File

@@ -8,6 +8,19 @@
namespace atidlas
{
void fill(array const & a, array_infos& i)
{
i.dtype = a.dtype();
i.data = a.data()();
i.shape1 = a.shape()._1;
i.shape2 = a.shape()._2;
i.start1 = a.start()._1;
i.start2 = a.start()._2;
i.stride1 = a.stride()._1;
i.stride2 = a.stride()._2;
i.ld = a.ld();
}
lhs_rhs_element::lhs_rhs_element()
{
type_family = INVALID_TYPE_FAMILY;
@@ -28,7 +41,8 @@ lhs_rhs_element::lhs_rhs_element(atidlas::array const & x)
type_family = ARRAY_TYPE_FAMILY;
subtype = DENSE_ARRAY_TYPE;
dtype = x.dtype();
array = (atidlas::array*)&x;
fill(x, array);
memory_ = x.data();
}
lhs_rhs_element::lhs_rhs_element(atidlas::value_scalar const & x)
@@ -44,7 +58,7 @@ lhs_rhs_element::lhs_rhs_element(atidlas::repeat_infos const & x)
type_family = INFOS_TYPE_FAMILY;
subtype = REPEAT_INFOS_TYPE;
dtype = INVALID_NUMERIC_TYPE;
tuple = (atidlas::repeat_infos*)&x;
tuple = x;
}
//

View File

@@ -43,46 +43,46 @@ void test(T epsilon, simple_matrix_base<T> & cA, simple_matrix_base<T>& cB, simp
std::cout << std::endl;\
}
// RUN_TEST("C = A", cC(i,j) = cA(i,j), C = A)
// RUN_TEST("C = A + B", cC(i,j) = cA(i,j) + cB(i,j), C = A + B)
// RUN_TEST("C = A - B", cC(i,j) = cA(i,j) - cB(i,j), C = A - B)
// RUN_TEST("C = A + B + C", cC(i,j) = cA(i,j) + cB(i,j) + cC(i,j), C = A + B + C)
RUN_TEST("C = A", cC(i,j) = cA(i,j), C = A)
RUN_TEST("C = A + B", cC(i,j) = cA(i,j) + cB(i,j), C = A + B)
RUN_TEST("C = A - B", cC(i,j) = cA(i,j) - cB(i,j), C = A - B)
RUN_TEST("C = A + B + C", cC(i,j) = cA(i,j) + cB(i,j) + cC(i,j), C = A + B + C)
// RUN_TEST("C = a*A", cC(i,j) = aa*cA(i,j), C = a*A)
// RUN_TEST("C = da*A", cC(i,j) = aa*cA(i,j), C = da*A)
// RUN_TEST("C = a*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + b*B)
// RUN_TEST("C = da*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + b*B)
// RUN_TEST("C = a*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + db*B)
// RUN_TEST("C = da*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + db*B)
RUN_TEST("C = a*A", cC(i,j) = aa*cA(i,j), C = a*A)
RUN_TEST("C = da*A", cC(i,j) = aa*cA(i,j), C = da*A)
RUN_TEST("C = a*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + b*B)
RUN_TEST("C = da*A + b*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + b*B)
RUN_TEST("C = a*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= a*A + db*B)
RUN_TEST("C = da*A + db*B", cC(i,j) = aa*cA(i,j) + bb*cB(i,j), C= da*A + db*B)
// RUN_TEST("C = exp(A)", cC(i,j) = exp(cA(i,j)), C= exp(A))
// RUN_TEST("C = abs(A)", cC(i,j) = abs(cA(i,j)), C= abs(A))
// RUN_TEST("C = acos(A)", cC(i,j) = acos(cA(i,j)), C= acos(A))
// RUN_TEST("C = asin(A)", cC(i,j) = asin(cA(i,j)), C= asin(A))
// RUN_TEST("C = atan(A)", cC(i,j) = atan(cA(i,j)), C= atan(A))
// RUN_TEST("C = ceil(A)", cC(i,j) = ceil(cA(i,j)), C= ceil(A))
// RUN_TEST("C = cos(A)", cC(i,j) = cos(cA(i,j)), C= cos(A))
// RUN_TEST("C = cosh(A)", cC(i,j) = cosh(cA(i,j)), C= cosh(A))
// RUN_TEST("C = floor(A)", cC(i,j) = floor(cA(i,j)), C= floor(A))
// RUN_TEST("C = log(A)", cC(i,j) = log(cA(i,j)), C= log(A))
// RUN_TEST("C = log10(A)", cC(i,j) = log10(cA(i,j)), C= log10(A))
// RUN_TEST("C = sin(A)", cC(i,j) = sin(cA(i,j)), C= sin(A))
// RUN_TEST("C = sinh(A)", cC(i,j) = sinh(cA(i,j)), C= sinh(A))
// RUN_TEST("C = sqrt(A)", cC(i,j) = sqrt(cA(i,j)), C= sqrt(A))
// RUN_TEST("C = tan(A)", cC(i,j) = tan(cA(i,j)), C= tan(A))
// RUN_TEST("C = tanh(A)", cC(i,j) = tanh(cA(i,j)), C= tanh(A))
RUN_TEST("C = exp(A)", cC(i,j) = exp(cA(i,j)), C= exp(A))
RUN_TEST("C = abs(A)", cC(i,j) = abs(cA(i,j)), C= abs(A))
RUN_TEST("C = acos(A)", cC(i,j) = acos(cA(i,j)), C= acos(A))
RUN_TEST("C = asin(A)", cC(i,j) = asin(cA(i,j)), C= asin(A))
RUN_TEST("C = atan(A)", cC(i,j) = atan(cA(i,j)), C= atan(A))
RUN_TEST("C = ceil(A)", cC(i,j) = ceil(cA(i,j)), C= ceil(A))
RUN_TEST("C = cos(A)", cC(i,j) = cos(cA(i,j)), C= cos(A))
RUN_TEST("C = cosh(A)", cC(i,j) = cosh(cA(i,j)), C= cosh(A))
RUN_TEST("C = floor(A)", cC(i,j) = floor(cA(i,j)), C= floor(A))
RUN_TEST("C = log(A)", cC(i,j) = log(cA(i,j)), C= log(A))
RUN_TEST("C = log10(A)", cC(i,j) = log10(cA(i,j)), C= log10(A))
RUN_TEST("C = sin(A)", cC(i,j) = sin(cA(i,j)), C= sin(A))
RUN_TEST("C = sinh(A)", cC(i,j) = sinh(cA(i,j)), C= sinh(A))
RUN_TEST("C = sqrt(A)", cC(i,j) = sqrt(cA(i,j)), C= sqrt(A))
RUN_TEST("C = tan(A)", cC(i,j) = tan(cA(i,j)), C= tan(A))
RUN_TEST("C = tanh(A)", cC(i,j) = tanh(cA(i,j)), C= tanh(A))
// RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
// RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
// RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
// RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
// RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
// RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
// RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
// RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
// RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
RUN_TEST("C = A.*B", cC(i,j) = cA(i,j)*cB(i,j), C= A*B)
RUN_TEST("C = A./B", cC(i,j) = cA(i,j)/cB(i,j), C= A/B)
RUN_TEST("C = A==B", cC(i,j) = cA(i,j)==cB(i,j), C= A==B)
RUN_TEST("C = A>=B", cC(i,j) = cA(i,j)>=cB(i,j), C= A>=B)
RUN_TEST("C = A>B", cC(i,j) = cA(i,j)>cB(i,j), C= A>B)
RUN_TEST("C = A<=B", cC(i,j) = cA(i,j)<=cB(i,j), C= A<=B)
RUN_TEST("C = A<B", cC(i,j) = cA(i,j)<cB(i,j), C= A<B)
RUN_TEST("C = A!=B", cC(i,j) = cA(i,j)!=cB(i,j), C= A!=B)
RUN_TEST("C = pow(A,B)", cC(i,j) = pow(cA(i,j), cB(i,j)), C= pow(A,B))
// RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype()))
RUN_TEST("C = eye(M, N)", cC(i,j) = i==j, C= eye(M, N, C.dtype()))
RUN_TEST("C = outer(x, y)", cC(i,j) = cx[i]*cy[j], C= outer(x,y))
#undef RUN_TEST

View File

@@ -57,8 +57,8 @@ void test_impl(T epsilon)
INIT_VECTOR(M, SUBM, 6, 2, cy, y);
INIT_VECTOR(N, SUBN, 4, 3, cx, x);
// std::cout << "full..." << std::endl;
// test_row_wise_reduction(epsilon, cy_full, cA_full, cx_full, y_full, A_full, x_full);
std::cout << "full..." << std::endl;
test_row_wise_reduction(epsilon, cy_full, cA_full, cx_full, y_full, A_full, x_full);
std::cout << "slice..." << std::endl;
test_row_wise_reduction(epsilon, cy_slice, cA_slice, cx_slice, y_slice, A_slice, x_slice);
}