Files
triton/include/atidlas/array.h

232 lines
7.7 KiB
C
Raw Normal View History

#ifndef ATIDLAS_ARRAY_H_
#define ATIDLAS_ARRAY_H_
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>

#include <CL/cl.hpp>

#include "atidlas/types.h"
#include "atidlas/cl_ext/backend.h"
#include "atidlas/symbolic/expression.h"
namespace atidlas
{
2015-01-13 01:17:27 -05:00
class scalar;
2015-02-04 22:06:15 -05:00
class array: public array_base
{
2015-01-18 14:52:45 -05:00
friend array reshape(array const &, int_t, int_t);
2015-02-05 04:42:57 -05:00
template<class T>
struct is_array { enum{ value = std::is_same<T, array>::value || std::is_same<T, array_expression>::value}; };
public:
//1D Constructors
2015-01-27 16:14:02 -05:00
array(int_t size1, numeric_type dtype, cl::Context context = cl_ext::default_context());
2015-01-21 20:08:52 -05:00
template<typename DT>
2015-01-27 16:14:02 -05:00
array(std::vector<DT> const & data, cl::Context context = cl_ext::default_context());
array(array & v, slice const & s1);
//2D Constructors
2015-01-27 16:14:02 -05:00
array(int_t size1, int_t size2, numeric_type dtype, cl::Context context = cl_ext::default_context());
2015-01-21 20:08:52 -05:00
template<typename DT>
2015-01-27 16:14:02 -05:00
array(int_t size1, int_t size2, std::vector<DT> const & data, cl::Context context = cl_ext::default_context());
array(array & M, slice const & s1, slice const & s2);
//General constructor
2015-01-27 16:14:02 -05:00
array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context = cl_ext::default_context());
2015-02-04 22:06:15 -05:00
array(array_expression const & proxy);
2015-01-18 14:52:45 -05:00
array(array const &);
2015-01-19 21:29:47 -05:00
//Getters
numeric_type dtype() const;
size4 shape() const;
2015-02-04 22:06:15 -05:00
int_t nshape() const;
size4 start() const;
size4 stride() const;
int_t ld() const;
cl::Context const & context() const;
cl::Buffer const & data() const;
int_t dsize() const;
//Setters
2015-01-16 07:31:39 -05:00
array& resize(int_t size1, int_t size2=1);
//Numeric operators
array& operator=(array const &);
2015-02-04 22:06:15 -05:00
array& operator=(array_expression const &);
2015-02-05 04:42:57 -05:00
template<class T>
array& operator=(controller<T> const &);
template<class T>
array & operator=(std::vector<T> const & rhs);
2015-01-19 21:29:47 -05:00
array_expression operator-();
2015-01-29 15:19:40 -05:00
array_expression operator!();
2015-01-16 07:31:39 -05:00
array& operator+=(value_scalar const &);
array& operator+=(array const &);
array& operator+=(array_expression const &);
array& operator-=(value_scalar const &);
array& operator-=(array const &);
array& operator-=(array_expression const &);
array& operator*=(value_scalar const &);
array& operator*=(array const &);
array& operator*=(array_expression const &);
array& operator/=(value_scalar const &);
array& operator/=(array const &);
array& operator/=(array_expression const &);
//Indexing operators
const scalar operator[](int_t) const;
2015-01-13 01:17:27 -05:00
scalar operator[](int_t);
array operator[](slice const &);
array operator()(slice const &, slice const &);
2015-01-21 20:08:52 -05:00
array_expression T() const;
protected:
2015-02-04 22:06:15 -05:00
numeric_type dtype_;
size4 shape_;
size4 start_;
size4 stride_;
int_t ld_;
cl::Context context_;
cl::Buffer data_;
};
class scalar : public array
{
2015-01-13 01:17:27 -05:00
private:
template<class T> T cast() const;
public:
2015-01-27 16:14:02 -05:00
explicit scalar(numeric_type dtype, cl::Buffer const & data, int_t offset, cl::Context context = cl_ext::default_context());
explicit scalar(value_scalar value, cl::Context context = cl_ext::default_context());
explicit scalar(numeric_type dtype, cl::Context context = cl_ext::default_context());
2015-02-04 22:06:15 -05:00
scalar(array_expression const & proxy);
2015-01-16 07:31:39 -05:00
scalar& operator=(value_scalar const &);
2015-01-19 21:29:47 -05:00
// scalar& operator=(scalar const & s);
using array::operator =;
2015-01-13 01:17:27 -05:00
#define INSTANTIATE(type) operator type() const;
2015-01-16 07:31:39 -05:00
INSTANTIATE(bool)
2015-01-13 01:17:27 -05:00
INSTANTIATE(cl_char)
INSTANTIATE(cl_uchar)
INSTANTIATE(cl_short)
INSTANTIATE(cl_ushort)
INSTANTIATE(cl_int)
INSTANTIATE(cl_uint)
INSTANTIATE(cl_long)
INSTANTIATE(cl_ulong)
INSTANTIATE(cl_float)
INSTANTIATE(cl_double)
#undef INSTANTIATE
};
2015-01-16 07:31:39 -05:00
2015-01-21 20:08:52 -05:00
2015-01-16 07:31:39 -05:00
//copy
void copy(void const * data, array & gx, cl::CommandQueue & queue, bool blocking = true);
void copy(array const & gx, void* data, cl::CommandQueue & queue, bool blocking = true);
void copy(void const *data, array &gx, bool blocking = true);
void copy(array const & gx, void* data, bool blocking = true);
template<class T> void copy(std::vector<T> const & cA, array& gA, cl::CommandQueue & queue, bool blocking = true);
template<class T> void copy(array const & gA, std::vector<T> & cA, cl::CommandQueue & queue, bool blocking = true);
template<class T> void copy(std::vector<T> const & cA, array & gA, bool blocking = true);
template<class T> void copy(array const & gA, std::vector<T> & cA, bool blocking = true);
// -----------------------------------------------------------------------------
// Binary operators and functions.
//
// Each name below is declared for eight operand combinations: array_expression
// or array on either side, plus a value_scalar on the left or on the right of
// an array_expression or an array.
// -----------------------------------------------------------------------------
#define ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(OPNAME) \
  array_expression OPNAME (array_expression const & lhs, array_expression const & rhs); \
  array_expression OPNAME (array const & lhs, array_expression const & rhs); \
  array_expression OPNAME (array_expression const & lhs, array const & rhs); \
  array_expression OPNAME (array const & lhs, array const & rhs); \
  array_expression OPNAME (array_expression const & lhs, value_scalar const & rhs); \
  array_expression OPNAME (array const & lhs, value_scalar const & rhs); \
  array_expression OPNAME (value_scalar const & lhs, array_expression const & rhs); \
  array_expression OPNAME (value_scalar const & lhs, array const & rhs);

// Arithmetic
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator +)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator -)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator *)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator /)

// Comparisons
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator >)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator >=)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator <)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator <=)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=)

// Named binary functions
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(maximum)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(minimum)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(outer)

// `assign` lives in `detail`, outside the user-facing operator set.
namespace detail
{
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(assign)
}

#undef ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR
// -----------------------------------------------------------------------------
// Unary functions: each is declared once for `array` and once for
// `array_expression`.
// -----------------------------------------------------------------------------
#define ATIDLAS_DECLARE_UNARY_OPERATOR(OPNAME) \
  array_expression OPNAME (array const & x); \
  array_expression OPNAME (array_expression const & x);

ATIDLAS_DECLARE_UNARY_OPERATOR(abs)
ATIDLAS_DECLARE_UNARY_OPERATOR(acos)
ATIDLAS_DECLARE_UNARY_OPERATOR(asin)
ATIDLAS_DECLARE_UNARY_OPERATOR(atan)
ATIDLAS_DECLARE_UNARY_OPERATOR(ceil)
ATIDLAS_DECLARE_UNARY_OPERATOR(cos)
ATIDLAS_DECLARE_UNARY_OPERATOR(cosh)
ATIDLAS_DECLARE_UNARY_OPERATOR(exp)
ATIDLAS_DECLARE_UNARY_OPERATOR(floor)
ATIDLAS_DECLARE_UNARY_OPERATOR(log)
ATIDLAS_DECLARE_UNARY_OPERATOR(log10)
ATIDLAS_DECLARE_UNARY_OPERATOR(sin)
ATIDLAS_DECLARE_UNARY_OPERATOR(sinh)
ATIDLAS_DECLARE_UNARY_OPERATOR(sqrt)
ATIDLAS_DECLARE_UNARY_OPERATOR(tan)
ATIDLAS_DECLARE_UNARY_OPERATOR(tanh)
ATIDLAS_DECLARE_UNARY_OPERATOR(trans)

// The helper is not needed past this point.
#undef ATIDLAS_DECLARE_UNARY_OPERATOR

// Conversion to element type `dtype`.
array_expression cast(array const & x, numeric_type dtype);
array_expression cast(array_expression const & x, numeric_type dtype);

// Norm of order `order` (defaults to 2).
array_expression norm(array const & x, unsigned int order = 2);
array_expression norm(array_expression const & x, unsigned int order = 2);
// Replicates the input `rep1` x `rep2` times.
// NOTE(review): presumably MATLAB-style tiling along the two axes; confirm in array.cpp.
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);

// Reductions, each declared for `array` and `array_expression`; `axis`
// defaults to -1. NOTE(review): -1 presumably means "reduce over every
// element"; confirm in array.cpp.
#define ATIDLAS_DECLARE_REDUCTION(OPNAME) \
  array_expression OPNAME(array const & M, int_t axis = -1);\
  array_expression OPNAME(array_expression const & M, int_t axis = -1);
ATIDLAS_DECLARE_REDUCTION(sum)
ATIDLAS_DECLARE_REDUCTION(argmax)
ATIDLAS_DECLARE_REDUCTION(max)
ATIDLAS_DECLARE_REDUCTION(min)
ATIDLAS_DECLARE_REDUCTION(argmin)
// This helper is local to the header, like the binary and unary declaration
// macros. Undefine it so it does not leak into every translation unit that
// includes this file.
#undef ATIDLAS_DECLARE_REDUCTION
2015-01-27 16:14:02 -05:00
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl_ext::default_context());
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, cl::Context ctx = cl_ext::default_context());
2015-01-21 20:08:52 -05:00
array reshape(array const &, int_t, int_t);
//
std::ostream& operator<<(std::ostream &, array const &);
2015-01-19 21:29:47 -05:00
std::ostream& operator<<(std::ostream & os, scalar const & s);
}
#endif