diff --git a/include/isaac/array.h b/include/isaac/array.h index 8a110d762..362d59d6b 100644 --- a/include/isaac/array.h +++ b/include/isaac/array.h @@ -12,13 +12,12 @@ namespace isaac { class scalar; +class view; class ISAACAPI array: public array_base { protected: - //Slices array(numeric_type dtype, driver::Buffer data, slice const & s1, slice const & s2, int_t ld); - public: //1D Constructors explicit array(int_t size1, numeric_type dtype = FLOAT_TYPE, driver::Context const & context = driver::backend::contexts::get_default()); @@ -88,8 +87,8 @@ public: math_expression operator[](for_idx_t idx) const; const scalar operator[](int_t) const; scalar operator[](int_t); - array operator[](slice const &); - array operator()(slice const &, slice const &); + view operator[](slice const &); + view operator()(slice const &, slice const &); protected: @@ -107,6 +106,13 @@ public: math_expression T; }; +class ISAACAPI view : public array +{ +public: + view(array& data, slice const & s1); + view(array& data, slice const & s1, slice const & s2); +}; + class ISAACAPI scalar : public array { friend value_scalar::value_scalar(const scalar &); diff --git a/lib/array.cpp b/lib/array.cpp index c2f62d1e5..dbe2de0d5 100644 --- a/lib/array.cpp +++ b/lib/array.cpp @@ -304,14 +304,20 @@ const scalar array::operator [](int_t idx) const } -array array::operator[](slice const & e1) +view array::operator[](slice const & e1) { assert(nshape()<=1); - return array(*this, e1); + return view(*this, e1); } -array array::operator()(slice const & e1, slice const & e2) -{ return array(*this, e1, e2); } +view array::operator()(slice const & e1, slice const & e2) +{ return view(*this, e1, e2); } + +//--------------------------------------- +/*--- View ---*/ +view::view(array& data, slice const & s1) : array(data, s1) {} +view::view(array& data, slice const & s1, slice const & s2) : array(data, s1, s2) {} + //--------------------------------------- /*--- Scalar ---*/ diff --git a/tests/linalg/common.hpp b/tests/linalg/common.hpp index 25495ef10..26f879295 100644 --- a/tests/linalg/common.hpp +++ b/tests/linalg/common.hpp @@ -187,7 +187,7 @@ bool diff(VecType1 const & x, VecType2 const & y, typename VecType1::value_type simple_vector_slice CPREFIX ## _slice(CPREFIX ## _full, START, START + STRIDE*SUBN, STRIDE);\ init_rand(CPREFIX ## _full);\ isaac::array PREFIX ## _full(CPREFIX ## _full.data(), CTX);\ - isaac::array PREFIX ## _slice = PREFIX ## _full[isaac::_(START, START + STRIDE*SUBN, STRIDE)]; + isaac::view PREFIX ## _slice = PREFIX ## _full[isaac::_(START, START + STRIDE*SUBN, STRIDE)]; #define INIT_MATRIX(M, SUBM, START1, STRIDE1, N, SUBN, START2, STRIDE2, CPREFIX, PREFIX, CTX) \ simple_matrix CPREFIX ## _full(M, N);\ @@ -195,11 +195,11 @@ bool diff(VecType1 const & x, VecType2 const & y, typename VecType1::value_type START2, START2 + STRIDE2*SUBN, STRIDE2);\ init_rand(CPREFIX ## _full);\ isaac::array PREFIX ## _full(M, N, CPREFIX ## _full.data(), CTX);\ - isaac::array PREFIX ## _slice(PREFIX ## _full(isaac::_(START1, START1 + STRIDE1*SUBM, STRIDE1),\ + isaac::view PREFIX ## _slice(PREFIX ## _full(isaac::_(START1, START1 + STRIDE1*SUBM, STRIDE1),\ isaac::_(START2, START2 + STRIDE2*SUBN, STRIDE2)));\ simple_matrix CPREFIX ## T_full = simple_trans(CPREFIX ## _full);\ isaac::array PREFIX ## T_full(N, M, CPREFIX ## T_full.data(), CTX);\ - isaac::array PREFIX ## T_slice(PREFIX ## T_full(isaac::_(START2, START2 + STRIDE2*SUBN, STRIDE2),\ + isaac::view PREFIX ## T_slice(PREFIX ## T_full(isaac::_(START2, START2 + STRIDE2*SUBN, STRIDE2),\ isaac::_(START1, START1 + STRIDE1*SUBM, STRIDE1)));\