API: polished slice construction

This commit is contained in:
Philippe Tillet
2015-10-03 19:30:50 -04:00
parent b5100f9d9a
commit 740f5def49
10 changed files with 33 additions and 48 deletions

View File

@@ -42,7 +42,7 @@ array::array(std::vector<DT> const & x, driver::Context const & context):
{ *this = x; }
array::array(array & v, slice const & s0) :
dtype_(v.dtype_), shape_(s0.size, 1, 1, 1), start_(v.start_[0] + v.stride_[0]*s0.start, 0, 0, 0), stride_(v.stride_[0]*s0.stride, 1, 1, 1),
dtype_(v.dtype_), shape_(s0.size(v.shape_[0]), 1, 1, 1), start_(v.start_[0] + v.stride_[0]*s0.start, 0, 0, 0), stride_(v.stride_[0]*s0.stride, 1, 1, 1),
ld_(v.ld_), context_(v.context()), data_(v.data_),
T(isaac::trans(*this))
{}
@@ -76,7 +76,7 @@ array::array(int_t shape0, int_t shape1, numeric_type dtype, driver::Buffer data
{ }
array::array(array & M, slice const & s0, slice const & s1) :
dtype_(M.dtype_), shape_(s0.size, s1.size, 1, 1),
dtype_(M.dtype_), shape_(s0.size(M.shape_[0]), s1.size(M.shape_[1]), 1, 1),
start_(M.start_[0] + M.stride_[0]*s0.start, M.start_[1] + M.stride_[1]*s1.start, 0, 0),
stride_(M.stride_[0]*s0.stride, M.stride_[1]*s1.stride, 1, 1), ld_(M.ld_),
context_(M.data_.context()), data_(M.data_),
@@ -101,12 +101,12 @@ array::array(int_t shape0, int_t shape1, int_t shape2, numeric_type dtype, drive
T(isaac::trans(*this))
{}
//Slices
array::array(numeric_type dtype, driver::Buffer data, slice const & s0, slice const & s1, int_t ld):
dtype_(dtype), shape_(s0.size, s1.size), start_(s0.start, s1.start), stride_(s0.stride, s1.stride),
ld_(ld), context_(data.context()), data_(data),
T(isaac::trans(*this))
{ }
////Slices
//array::array(numeric_type dtype, driver::Buffer data, slice const & s0, slice const & s1, int_t ld):
// dtype_(dtype), shape_(s0.size, s1.size), start_(s0.start, s1.start), stride_(s0.stride, s1.stride),
// ld_(ld), context_(data.context()), data_(data),
// T(isaac::trans(*this))
//{ }
@@ -332,7 +332,7 @@ void copy(driver::Context const & context, driver::Buffer const & data, T value)
}
scalar::scalar(numeric_type dtype, const driver::Buffer &data, int_t offset): array(dtype, data, _(offset, offset+1), _(1,2), 1)
scalar::scalar(numeric_type dtype, const driver::Buffer &data, int_t offset): array(1, dtype, data, offset, 1)
{ }
scalar::scalar(value_scalar value, driver::Context const & context) : array(1, value.dtype(), context)