API: clearer interface for transposition

This commit is contained in:
Philippe Tillet
2015-10-01 17:23:26 -04:00
parent feeb1e9862
commit 1e076c131b
5 changed files with 50 additions and 32 deletions

View File

@@ -24,22 +24,27 @@ namespace isaac
array::array(int_t shape0, numeric_type dtype, driver::Context const & context) :
dtype_(dtype), shape_(shape0, 1, 1, 1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape_[0]),
context_(context), data_(context_, size_of(dtype)*dsize())
context_(context), data_(context_, size_of(dtype)*dsize()),
T(isaac::trans(*this))
{ }
array::array(int_t shape0, numeric_type dtype, driver::Buffer data, int_t start, int_t inc):
dtype_(dtype), shape_(shape0), start_(start, 0, 0, 0), stride_(inc), ld_(shape_[0]), context_(data.context()), data_(data)
dtype_(dtype), shape_(shape0), start_(start, 0, 0, 0), stride_(inc), ld_(shape_[0]), context_(data.context()), data_(data),
T(isaac::trans(*this))
{ }
template<class DT>
array::array(std::vector<DT> const & x, driver::Context const & context):
dtype_(to_numeric_type<DT>::value), shape_((int_t)x.size(), 1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape_[0]),
context_(context), data_(context, size_of(dtype_)*dsize())
context_(context), data_(context, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{ *this = x; }
array::array(array & v, slice const & s0) : dtype_(v.dtype_), shape_(s0.size, 1, 1, 1), start_(v.start_[0] + v.stride_[0]*s0.start, 0, 0, 0), stride_(v.stride_[0]*s0.stride, 1, 1, 1),
ld_(v.ld_), context_(v.context()), data_(v.data_)
array::array(array & v, slice const & s0) :
dtype_(v.dtype_), shape_(s0.size, 1, 1, 1), start_(v.start_[0] + v.stride_[0]*s0.start, 0, 0, 0), stride_(v.stride_[0]*s0.stride, 1, 1, 1),
ld_(v.ld_), context_(v.context()), data_(v.data_),
T(isaac::trans(*this))
{}
#define INSTANTIATE(T) template ISAACAPI array::array(std::vector<T> const &, driver::Context const &)
@@ -58,18 +63,24 @@ INSTANTIATE(double);
#undef INSTANTIATE
// 2D
array::array(int_t shape0, int_t shape1, numeric_type dtype, driver::Context const & context) : dtype_(dtype), shape_(shape0, shape1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape0),
context_(context), data_(context_, size_of(dtype_)*dsize())
array::array(int_t shape0, int_t shape1, numeric_type dtype, driver::Context const & context) :
dtype_(dtype), shape_(shape0, shape1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape0),
context_(context), data_(context_, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{}
array::array(int_t shape0, int_t shape1, numeric_type dtype, driver::Buffer data, int_t start, int_t ld) :
dtype_(dtype), shape_(shape0, shape1), start_(start, 0, 0, 0), stride_(1, 1, 1, 1), ld_(ld), context_(data.context()), data_(data)
dtype_(dtype), shape_(shape0, shape1), start_(start, 0, 0, 0), stride_(1, 1, 1, 1),
ld_(ld), context_(data.context()), data_(data),
T(isaac::trans(*this))
{ }
array::array(array & M, slice const & s0, slice const & s1) : dtype_(M.dtype_), shape_(s0.size, s1.size, 1, 1),
start_(M.start_[0] + M.stride_[0]*s0.start, M.start_[1] + M.stride_[1]*s1.start, 0, 0),
stride_(M.stride_[0]*s0.stride, M.stride_[1]*s1.stride, 1, 1), ld_(M.ld_),
context_(M.data_.context()), data_(M.data_)
array::array(array & M, slice const & s0, slice const & s1) :
dtype_(M.dtype_), shape_(s0.size, s1.size, 1, 1),
start_(M.start_[0] + M.stride_[0]*s0.start, M.start_[1] + M.stride_[1]*s1.start, 0, 0),
stride_(M.stride_[0]*s0.stride, M.stride_[1]*s1.stride, 1, 1), ld_(M.ld_),
context_(M.data_.context()), data_(M.data_),
T(isaac::trans(*this))
{ }
@@ -77,20 +88,24 @@ template<typename DT>
array::array(int_t shape0, int_t shape1, std::vector<DT> const & data, driver::Context const & context)
: dtype_(to_numeric_type<DT>::value),
shape_(shape0, shape1), start_(0, 0), stride_(1, 1), ld_(shape0),
context_(context), data_(context_, size_of(dtype_)*dsize())
context_(context), data_(context_, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{
isaac::copy(data, *this);
}
// 3D
array::array(int_t shape0, int_t shape1, int_t shape2, numeric_type dtype, driver::Context const & context) : dtype_(dtype), shape_(shape0, shape1, shape2, 1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape0),
context_(context), data_(context_, size_of(dtype_)*dsize())
array::array(int_t shape0, int_t shape1, int_t shape2, numeric_type dtype, driver::Context const & context) :
dtype_(dtype), shape_(shape0, shape1, shape2, 1), start_(0, 0, 0, 0), stride_(1, 1, 1, 1), ld_(shape0),
context_(context), data_(context_, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{}
//Slices
array::array(numeric_type dtype, driver::Buffer data, slice const & s0, slice const & s1, int_t ld):
dtype_(dtype), shape_(s0.size, s1.size), start_(s0.start, s1.start), stride_(s0.stride, s1.stride),
ld_(ld), context_(data.context()), data_(data)
ld_(ld), context_(data.context()), data_(data),
T(isaac::trans(*this))
{ }
@@ -112,9 +127,11 @@ INSTANTIATE(double);
array::array(math_expression const & proxy) : array(execution_handler(proxy)){}
array::array(array const & other): dtype_(other.dtype()),
array::array(array const & other):
dtype_(other.dtype()),
shape_(other.shape()), start_(0,0), stride_(1, 1), ld_(shape_[0]),
context_(other.context()), data_(context_, size_of(dtype_)*dsize())
context_(other.context()), data_(context_, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{
*this = other;
}
@@ -122,7 +139,8 @@ array::array(array const & other): dtype_(other.dtype()),
array::array(execution_handler const & other) :
dtype_(other.x().dtype()),
shape_(other.x().shape()), start_(0,0), stride_(1, 1), ld_(shape_[0]),
context_(other.x().context()), data_(context_, size_of(dtype_)*dsize())
context_(other.x().context()), data_(context_, size_of(dtype_)*dsize()),
T(isaac::trans(*this))
{
*this = other;
}
@@ -266,9 +284,6 @@ array & array::operator/=(array const & rhs)
array & array::operator/=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
math_expression array::T() const
{ return isaac::trans(*this) ;}
/*--- Indexing operators -----*/
//---------------------------------------
math_expression array::operator[](for_idx_t idx) const