Feature: Merged kernel-fusion branch
* Fuses multiple AXPY kernels * Adds the possibility of thread-wise for loops in AXPY-like kernels
This commit is contained in:
370
lib/array.cpp
370
lib/array.cpp
@@ -5,10 +5,11 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/tuple.h"
|
||||
#include "isaac/exception/unknown_datatype.h"
|
||||
#include "isaac/profiles/profiles.h"
|
||||
#include "isaac/symbolic/execute.h"
|
||||
|
||||
#include "isaac/symbolic/io.h"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
@@ -109,11 +110,16 @@ INSTANTIATE(float);
|
||||
INSTANTIATE(double);
|
||||
#undef INSTANTIATE
|
||||
|
||||
array::array(array_expression const & proxy) : array(control(proxy)){}
|
||||
array::array(array const & other) : array(control(other)){}
|
||||
array::array(math_expression const & proxy) : array(execution_handler(proxy)){}
|
||||
|
||||
template<class TYPE>
|
||||
array::array(controller<TYPE> const & other) :
|
||||
// Copy constructor.
// Takes dtype, shape and context from `other`, resets start to (0,0) and
// stride to (1,1), uses shape_[0] as the leading dimension, and allocates a
// buffer of size_of(dtype_)*dsize() (presumably bytes -- confirm against the
// buffer type). The element values are then transferred through
// array::operator=(array const &).
// NOTE(review): ld_ and data_ read members named earlier in this list; this
// relies on the declaration order in the class header -- confirm.
array::array(array const & other)
  : dtype_(other.dtype())
  , shape_(other.shape())
  , start_(0,0)
  , stride_(1, 1)
  , ld_(shape_[0])
  , context_(other.context())
  , data_(context_, size_of(dtype_)*dsize())
{
  operator=(other);
}
|
||||
|
||||
array::array(execution_handler const & other) :
|
||||
dtype_(other.x().dtype()),
|
||||
shape_(other.x().shape()), start_(0,0), stride_(1, 1), ld_(shape_[0]),
|
||||
context_(other.x().context()), data_(context_, size_of(dtype_)*dsize())
|
||||
@@ -121,13 +127,11 @@ array::array(controller<TYPE> const & other) :
|
||||
*this = other;
|
||||
}
|
||||
|
||||
|
||||
template ISAACAPI array::array(controller<array> const&);
|
||||
template ISAACAPI array::array(controller<array_expression> const&);
|
||||
|
||||
/*--- Getters ---*/
|
||||
numeric_type array::dtype() const
|
||||
{ return dtype_; }
|
||||
{
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
size4 const & array::shape() const
|
||||
{ return shape_; }
|
||||
@@ -159,26 +163,37 @@ int_t array::dsize() const
|
||||
|
||||
/*--- Assignment Operators ----*/
|
||||
//---------------------------------------
|
||||
|
||||
array & array::operator=(array const & rhs)
|
||||
{ return *this = controller<array>(rhs); }
|
||||
{
|
||||
assert(dtype_ == rhs.dtype());
|
||||
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression));
|
||||
return *this;
|
||||
}
|
||||
|
||||
array & array::operator=(array_expression const & rhs)
|
||||
{ return *this = controller<array_expression>(rhs); }
|
||||
// Assigns a host scalar to this array.
// Builds an ASSIGN expression with this array on the left and `rhs` on the
// right (using this array's context, dtype and shape) and executes it.
// The scalar's dtype must match the array's dtype (checked with assert only).
// Returns *this.
array & array::operator=(value_scalar const & rhs)
{
  assert(dtype_ == rhs.dtype());

  op_element const assign_op(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE);
  math_expression assignment(*this, rhs, assign_op, context_, dtype_, shape_);
  execute(execution_handler(assignment));

  return *this;
}
|
||||
|
||||
template<class TYPE>
|
||||
array& array::operator=(controller<TYPE> const & c)
|
||||
|
||||
array& array::operator=(execution_handler const & c)
|
||||
{
|
||||
assert(dtype_ == c.x().dtype());
|
||||
array_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(controller<array_expression>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()),
|
||||
isaac::profiles::get(c.execution_options().queue(context_)));
|
||||
math_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
#define INSTANTIATE(TYPE) template ISAACAPI array& array::operator=<TYPE>(controller<TYPE> const &)
|
||||
INSTANTIATE(array);
|
||||
INSTANTIATE(array_expression);
|
||||
#undef INSTANTIATE
|
||||
// Assigns the result of a symbolic expression to this array.
// Wraps `rhs` in an execution_handler constructed from the expression alone
// and forwards to array::operator=(execution_handler const &).
// Returns *this.
array & array::operator=(math_expression const & rhs)
{
  execution_handler const handler(rhs);
  return operator=(handler);
}
|
||||
|
||||
|
||||
template<class DT>
|
||||
array & array::operator=(std::vector<DT> const & rhs)
|
||||
@@ -204,60 +219,63 @@ INSTANTIATE(float);
|
||||
INSTANTIATE(double);
|
||||
#undef INSTANTIATE
|
||||
|
||||
array & array::operator=(value_scalar const & rhs)
|
||||
{ return *this = controller<value_scalar>(rhs); }
|
||||
|
||||
|
||||
|
||||
|
||||
array_expression array::operator-()
|
||||
{ return array_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
math_expression array::operator-()
|
||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_expression array::operator!()
|
||||
{ return array_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
math_expression array::operator!()
|
||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
|
||||
//
|
||||
array & array::operator+=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator+=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator+=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array & array::operator+=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator-=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator-=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator-=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array & array::operator-=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator*=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator*=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator*=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array & array::operator*=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array & array::operator/=(value_scalar const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator/=(array const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array & array::operator/=(array_expression const & rhs)
|
||||
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array & array::operator/=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
|
||||
array_expression array::T() const
|
||||
math_expression array::T() const
|
||||
{ return isaac::trans(*this) ;}
|
||||
|
||||
/*--- Indexing operators -----*/
|
||||
//---------------------------------------
|
||||
// Symbolic indexing by a loop index.
// Returns an ACCESS_INDEX expression node with this array as the left operand
// and `idx` as the right operand. The expression carries this array's context,
// dtype and shape; nothing is executed here.
math_expression array::operator[](for_idx_t idx) const
{
  op_element const access_op(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE);
  math_expression indexed(*this, idx, access_op, context_, dtype_, shape_);
  return indexed;
}
|
||||
|
||||
scalar array::operator [](int_t idx)
|
||||
{
|
||||
assert(nshape()<=1);
|
||||
@@ -318,7 +336,7 @@ scalar::scalar(value_scalar value, driver::Context const & context) : array(1, v
|
||||
scalar::scalar(numeric_type dtype, driver::Context const & context) : array(1, dtype, context)
|
||||
{ }
|
||||
|
||||
scalar::scalar(array_expression const & proxy) : array(proxy){ }
|
||||
scalar::scalar(math_expression const & proxy) : array(proxy){ }
|
||||
|
||||
void scalar::inject(values_holder & v) const
|
||||
{
|
||||
@@ -445,37 +463,63 @@ size4 elementwise_size(U const & u, V const & v)
|
||||
template<class U, class V>
|
||||
bool check_elementwise(U const & u, V const & v)
|
||||
{
|
||||
return true;
|
||||
return detail::max(u.shape())==1 || detail::max(v.shape())==1 || u.shape()==v.shape();
|
||||
}
|
||||
|
||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
||||
array_expression OPNAME (array_expression const & x, array_expression const & y) \
|
||||
math_expression OPNAME (array const & x, math_expression const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
math_expression OPNAME (array const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); }\
|
||||
\
|
||||
math_expression OPNAME (array const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (array const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, math_expression const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, array_expression const & y) \
|
||||
math_expression OPNAME (math_expression const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); } \
|
||||
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, array const & y) \
|
||||
{ assert(check_elementwise(x, y));\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, elementwise_size(x, y)); }\
|
||||
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x, value_scalar const & y) \
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (array const & x, value_scalar const & y) \
|
||||
{ return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
array_expression OPNAME (value_scalar const & y, array_expression const & x) \
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
math_expression OPNAME (value_scalar const & y, array const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
array_expression OPNAME (value_scalar const & y, array const & x) \
|
||||
{ return array_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }
|
||||
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, array const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); }
|
||||
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
|
||||
@@ -497,28 +541,50 @@ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
||||
|
||||
#define DEFINE_OUTER(LTYPE, RTYPE) \
|
||||
array_expression outer(LTYPE const & x, RTYPE const & y)\
|
||||
math_expression outer(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
assert(x.nshape()==1 && y.nshape()==1);\
|
||||
return array_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(detail::max(x.shape()), detail::max(y.shape())) );\
|
||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), size4(detail::max(x.shape()), detail::max(y.shape())) );\
|
||||
}\
|
||||
|
||||
DEFINE_OUTER(array, array)
|
||||
DEFINE_OUTER(array_expression, array)
|
||||
DEFINE_OUTER(array, array_expression)
|
||||
DEFINE_OUTER(array_expression, array_expression)
|
||||
DEFINE_OUTER(math_expression, array)
|
||||
DEFINE_OUTER(array, math_expression)
|
||||
DEFINE_OUTER(math_expression, math_expression)
|
||||
|
||||
#undef DEFINE_ELEMENT_BINARY_OPERATOR
|
||||
|
||||
// Generates one overload of rot(x, y, c, s) for the given operand types.
// The overload returns a single expression that fuses two assignments:
//   x <- c*x + s*y
//   y <- c*y - s*x
// NOTE(review): whether the second assignment observes the original or the
// updated x depends on the semantics of fuse(), which is not visible here.
#define DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE) \
  math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s) \
  { \
    return fuse(assign(x, c*x + s*y), \
                assign(y, c*y - s*x)); \
  }
|
||||
|
||||
DEFINE_ROT(array, array, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, array, scalar, scalar)
|
||||
DEFINE_ROT(array, math_expression, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, scalar, scalar)
|
||||
|
||||
DEFINE_ROT(array, array, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, array, value_scalar, value_scalar)
|
||||
DEFINE_ROT(array, math_expression, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, value_scalar, value_scalar)
|
||||
|
||||
DEFINE_ROT(array, array, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, array, math_expression, math_expression)
|
||||
DEFINE_ROT(array, math_expression, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
|
||||
|
||||
|
||||
//---------------------------------------
|
||||
|
||||
/*--- Math Operators----*/
|
||||
//---------------------------------------
|
||||
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
|
||||
array_expression OPNAME (array const & x) \
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
math_expression OPNAME (array const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
\
|
||||
array_expression OPNAME (array_expression const & x) \
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
math_expression OPNAME (math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
|
||||
@@ -562,17 +628,17 @@ inline operation_node_type casted(numeric_type dtype)
|
||||
}
|
||||
}
|
||||
|
||||
array_expression cast(array const & x, numeric_type dtype)
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
math_expression cast(array const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
array_expression cast(array_expression const & x, numeric_type dtype)
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
math_expression cast(math_expression const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
isaac::array_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
isaac::array_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return array_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
inline size4 flip(size4 const & shape)
|
||||
{ return size4(shape[1], shape[0]);}
|
||||
@@ -580,59 +646,77 @@ inline size4 flip(size4 const & shape)
|
||||
inline size4 prod(size4 const & shape1, size4 const & shape2)
|
||||
{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
|
||||
|
||||
array_expression trans(array const & x) \
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
math_expression trans(array const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
\
|
||||
array_expression trans(array_expression const & x) \
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
math_expression trans(math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
|
||||
array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
|
||||
math_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
repeat_infos infos;
|
||||
infos.rep1 = rep1;
|
||||
infos.rep2 = rep2;
|
||||
infos.sub1 = A.shape()[0];
|
||||
infos.sub2 = A.shape()[1];
|
||||
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.shape()[1];
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(rep1*sub1, rep2*sub2));
|
||||
}
|
||||
|
||||
array_expression repmat(array_expression const & A, int_t const & rep1, int_t const & rep2)
|
||||
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
repeat_infos infos;
|
||||
infos.rep1 = rep1;
|
||||
infos.rep2 = rep2;
|
||||
infos.sub1 = A.shape()[0];
|
||||
infos.sub2 = A.shape()[1];
|
||||
return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(infos.rep1*infos.sub1, infos.rep2*infos.sub2));
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.shape()[1];
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), size4(rep1*sub1, rep2*sub2));
|
||||
}
|
||||
|
||||
// Generates one overload of row(x, i) for the given operand/index types.
// The overload returns a MATRIX_ROW expression over `x` and `i`, using x's
// context and dtype. The result shape is (x.shape()[1], 1), i.e. one entry per
// column of x.
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
  math_expression row(TYPEA const & x, TYPEB const & i) \
  { \
    return math_expression(x, i, \
                           op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_ROW_TYPE), \
                           x.context(), x.dtype(), size4(x.shape()[1], 1)); \
  }
|
||||
|
||||
DEFINE_ACCESS_ROW(array, value_scalar)
|
||||
DEFINE_ACCESS_ROW(array, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(array, math_expression)
|
||||
|
||||
DEFINE_ACCESS_ROW(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_ROW(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(math_expression, math_expression)
|
||||
|
||||
// Generates one overload of col(x, i) for the given operand/index types.
// The overload returns a MATRIX_COLUMN expression over `x` and `i`, using x's
// context and dtype.
// A column has one entry per row of x, so the result shape is
// (x.shape()[0], 1). The previous code used x.shape()[1], copied from the row
// accessor, which gives the wrong length for any non-square operand.
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
  math_expression col(TYPEA const & x, TYPEB const & i) \
  { \
    return math_expression(x, i, \
                           op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_COLUMN_TYPE), \
                           x.context(), x.dtype(), size4(x.shape()[0], 1)); \
  }
|
||||
|
||||
DEFINE_ACCESS_COL(array, value_scalar)
|
||||
DEFINE_ACCESS_COL(array, for_idx_t)
|
||||
DEFINE_ACCESS_COL(array, math_expression)
|
||||
|
||||
DEFINE_ACCESS_COL(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_COL(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_COL(math_expression, math_expression)
|
||||
|
||||
////---------------------------------------
|
||||
|
||||
///*--- Reductions ---*/
|
||||
////---------------------------------------
|
||||
#define DEFINE_DOT(OP, OPNAME)\
|
||||
array_expression OPNAME(array const & x, int_t axis)\
|
||||
math_expression OPNAME(array const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.nshape())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
else if(axis==-1)\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||
else if(axis==0)\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||
else\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||
}\
|
||||
\
|
||||
array_expression OPNAME(array_expression const & x, int_t axis)\
|
||||
math_expression OPNAME(math_expression const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.nshape())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
if(axis==-1)\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||
else if(axis==0)\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||
else\
|
||||
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||
}
|
||||
|
||||
DEFINE_DOT(OPERATOR_ADD_TYPE, sum)
|
||||
@@ -646,51 +730,51 @@ DEFINE_DOT(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
|
||||
namespace detail
|
||||
{
|
||||
|
||||
array_expression matmatprod(array const & A, array const & B)
|
||||
math_expression matmatprod(array const & A, array const & B)
|
||||
{
|
||||
size4 shape(A.shape()[0], B.shape()[1]);
|
||||
return array_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
return math_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
}
|
||||
|
||||
array_expression matmatprod(array_expression const & A, array const & B)
|
||||
math_expression matmatprod(math_expression const & A, array const & B)
|
||||
{
|
||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||
size4 shape(A.shape()[0], B.shape()[1]);
|
||||
|
||||
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||
if(A_trans){
|
||||
type = OPERATOR_GEMM_TN_TYPE;
|
||||
}
|
||||
|
||||
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
array_expression matmatprod(array const & A, array_expression const & B)
|
||||
math_expression matmatprod(array const & A, math_expression const & B)
|
||||
{
|
||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||
size4 shape(A.shape()[0], B.shape()[1]);
|
||||
|
||||
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
||||
if(B_trans){
|
||||
type = OPERATOR_GEMM_NT_TYPE;
|
||||
}
|
||||
|
||||
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
array_expression matmatprod(array_expression const & A, array_expression const & B)
|
||||
math_expression matmatprod(math_expression const & A, math_expression const & B)
|
||||
{
|
||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
||||
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
size4 shape(A.shape()[0], B.shape()[1]);
|
||||
|
||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||
@@ -701,15 +785,15 @@ namespace detail
|
||||
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
|
||||
else type = OPERATOR_GEMM_NN_TYPE;
|
||||
|
||||
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
array_expression matvecprod(array const & A, T const & x)
|
||||
math_expression matvecprod(array const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
@@ -717,11 +801,11 @@ namespace detail
|
||||
}
|
||||
|
||||
template<class T>
|
||||
array_expression matvecprod(array_expression const & A, T const & x)
|
||||
math_expression matvecprod(math_expression const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||
while(A_root.lhs.type_family==COMPOSITE_OPERATOR_FAMILY){
|
||||
A_root = A.tree()[A_root.lhs.node_index];
|
||||
@@ -729,7 +813,7 @@ namespace detail
|
||||
}
|
||||
if(A_trans)
|
||||
{
|
||||
array_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), size4(N, M));
|
||||
math_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), size4(N, M));
|
||||
//Remove trans
|
||||
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
||||
return sum(tmp, 0);
|
||||
@@ -741,15 +825,15 @@ namespace detail
|
||||
|
||||
}
|
||||
|
||||
array_expression reshape(array const & x, int_t shape0, int_t shape1)
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), size4(shape0, shape1)); }
|
||||
math_expression reshape(array const & x, int_t shape0, int_t shape1)
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), size4(shape0, shape1)); }
|
||||
|
||||
array_expression reshape(array_expression const & x, int_t shape0, int_t shape1)
|
||||
{ return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), size4(shape0, shape1)); }
|
||||
math_expression reshape(math_expression const & x, int_t shape0, int_t shape1)
|
||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), size4(shape0, shape1)); }
|
||||
|
||||
|
||||
#define DEFINE_DOT(LTYPE, RTYPE) \
|
||||
array_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
if(x.nshape()<1 || y.nshape()<1){\
|
||||
return x*y;\
|
||||
@@ -771,15 +855,15 @@ array_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
}
|
||||
|
||||
DEFINE_DOT(array, array)
|
||||
DEFINE_DOT(array_expression, array)
|
||||
DEFINE_DOT(array, array_expression)
|
||||
DEFINE_DOT(array_expression, array_expression)
|
||||
DEFINE_DOT(math_expression, array)
|
||||
DEFINE_DOT(array, math_expression)
|
||||
DEFINE_DOT(math_expression, math_expression)
|
||||
|
||||
#undef DEFINE_DOT
|
||||
|
||||
|
||||
#define DEFINE_NORM(TYPE)\
|
||||
array_expression norm(TYPE const & x, unsigned int order)\
|
||||
math_expression norm(TYPE const & x, unsigned int order)\
|
||||
{\
|
||||
assert(order > 0 && order < 3);\
|
||||
switch(order)\
|
||||
@@ -790,10 +874,24 @@ array_expression norm(TYPE const & x, unsigned int order)\
|
||||
}
|
||||
|
||||
DEFINE_NORM(array)
|
||||
DEFINE_NORM(array_expression)
|
||||
DEFINE_NORM(math_expression)
|
||||
|
||||
#undef DEFINE_NORM
|
||||
|
||||
/*--- Fusion ----*/
|
||||
/**
 * Joins two expressions under a single binary OPERATOR_FUSE node.
 *
 * Both operands are asserted to share the same context. The context, dtype
 * and shape of the resulting expression are all taken from x.
 *
 * @param x left operand of the fuse node
 * @param y right operand of the fuse node
 * @return the expression holding x and y under one OPERATOR_FUSE root
 */
math_expression fuse(math_expression const & x, math_expression const & y)
{
  assert(x.context()==y.context());
  op_element const fuse_op(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE);
  return math_expression(x, y, fuse_op, x.context(), x.dtype(), x.shape());
}
|
||||
|
||||
/*--- For loops ---*/
|
||||
/**
 * Builds an OPERATOR_SFOR_TYPE expression around x.
 *
 * The loop bounds are packed with make_tuple(x.context(), start, end, inc)
 * and attached as the second operand of the node; x is the first operand.
 * The context, dtype and shape of the result are those of x.
 *
 * @param start expression for the loop start
 * @param end   expression for the loop end
 * @param inc   expression for the loop increment
 * @param x     expression placed under the loop node
 * @return the expression whose root is the OPERATOR_SFOR_TYPE node
 */
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
{
  op_element const sfor_op(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE);
  return math_expression(x, make_tuple(x.context(), start, end, inc), sfor_op, x.context(), x.dtype(), x.shape());
}
|
||||
|
||||
|
||||
/*--- Copy ----*/
|
||||
//---------------------------------------
|
||||
|
||||
|
Reference in New Issue
Block a user