Code quality: renamed math_expression -> expression_tree
This commit is contained in:
300
lib/array.cpp
300
lib/array.cpp
@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == rhs.dtype());
|
||||
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression));
|
||||
return *this;
|
||||
}
|
||||
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == rhs.dtype());
|
||||
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression));
|
||||
return *this;
|
||||
}
|
||||
@@ -188,12 +188,12 @@ array_base& array_base::operator=(execution_handler const & c)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == c.x().dtype());
|
||||
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
array_base & array_base::operator=(math_expression const & rhs)
|
||||
array_base & array_base::operator=(expression_tree const & rhs)
|
||||
{
|
||||
return *this = execution_handler(rhs);
|
||||
}
|
||||
@@ -227,54 +227,54 @@ INSTANTIATE(double);
|
||||
|
||||
|
||||
|
||||
math_expression array_base::operator-()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
expression_tree array_base::operator-()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
math_expression array_base::operator!()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
expression_tree array_base::operator!()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
|
||||
//
|
||||
array_base & array_base::operator+=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator+=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator+=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator+=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator-=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator-=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator-=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator-=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator*=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator*=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator*=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator*=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator/=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator/=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator/=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator/=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
|
||||
/*--- Indexing operators -----*/
|
||||
//---------------------------------------
|
||||
math_expression array_base::operator[](for_idx_t idx) const
|
||||
expression_tree array_base::operator[](for_idx_t idx) const
|
||||
{
|
||||
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
||||
return expression_tree(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
||||
}
|
||||
|
||||
scalar array_base::operator [](int_t idx)
|
||||
@@ -324,7 +324,7 @@ view array_base::operator()(slice const & si, slice const & sj)
|
||||
//---------------------------------------
|
||||
/*--- array ---*/
|
||||
|
||||
array::array(math_expression const & proxy) : array_base(execution_handler(proxy)) {}
|
||||
array::array(expression_tree const & proxy) : array_base(execution_handler(proxy)) {}
|
||||
|
||||
array::array(array_base const & other): array_base(other.dtype(), other.shape(), other.context())
|
||||
{ *this = other; }
|
||||
@@ -379,7 +379,7 @@ scalar::scalar(value_scalar value, driver::Context const & context) : array_base
|
||||
scalar::scalar(numeric_type dtype, driver::Context const & context) : array_base(1, dtype, context)
|
||||
{ }
|
||||
|
||||
scalar::scalar(math_expression const & proxy) : array_base(proxy){ }
|
||||
scalar::scalar(expression_tree const & proxy) : array_base(proxy){ }
|
||||
|
||||
void scalar::inject(values_holder & v) const
|
||||
{
|
||||
@@ -511,53 +511,53 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
|
||||
}
|
||||
|
||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
||||
math_expression OPNAME (array_base const & x, math_expression const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (array_base const & x, expression_tree const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, array_base const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
||||
expression_tree OPNAME (array_base const & x, array_base const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x, value_scalar const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, math_expression const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (expression_tree const & x, expression_tree const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, array_base const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (expression_tree const & x, array_base const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (expression_tree const & x, value_scalar const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (expression_tree const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (value_scalar const & y, expression_tree const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & y, array_base const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (value_scalar const & y, array_base const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||
expression_tree OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (for_idx_t const & y, expression_tree const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||
expression_tree OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (for_idx_t const & y, array_base const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
|
||||
expression_tree OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
|
||||
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
|
||||
@@ -580,39 +580,39 @@ DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
||||
|
||||
#define DEFINE_OUTER(LTYPE, RTYPE) \
|
||||
math_expression outer(LTYPE const & x, RTYPE const & y)\
|
||||
expression_tree outer(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
assert(x.dim()<=1 && y.dim()<=1);\
|
||||
if(x.dim()<1 || y.dim()<1)\
|
||||
return x*y;\
|
||||
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
||||
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
||||
}\
|
||||
|
||||
DEFINE_OUTER(array_base, array_base)
|
||||
DEFINE_OUTER(math_expression, array_base)
|
||||
DEFINE_OUTER(array_base, math_expression)
|
||||
DEFINE_OUTER(math_expression, math_expression)
|
||||
DEFINE_OUTER(expression_tree, array_base)
|
||||
DEFINE_OUTER(array_base, expression_tree)
|
||||
DEFINE_OUTER(expression_tree, expression_tree)
|
||||
|
||||
#undef DEFINE_ELEMENT_BINARY_OPERATOR
|
||||
|
||||
#define DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE)\
|
||||
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
|
||||
expression_tree rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
|
||||
{ return fuse(assign(x, c*x + s*y), assign(y, c*y - s*x)); }
|
||||
|
||||
DEFINE_ROT(array_base, array_base, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, array_base, scalar, scalar)
|
||||
DEFINE_ROT(array_base, math_expression, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, scalar, scalar)
|
||||
DEFINE_ROT(expression_tree, array_base, scalar, scalar)
|
||||
DEFINE_ROT(array_base, expression_tree, scalar, scalar)
|
||||
DEFINE_ROT(expression_tree, expression_tree, scalar, scalar)
|
||||
|
||||
DEFINE_ROT(array_base, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(array_base, math_expression, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, value_scalar, value_scalar)
|
||||
DEFINE_ROT(expression_tree, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(array_base, expression_tree, value_scalar, value_scalar)
|
||||
DEFINE_ROT(expression_tree, expression_tree, value_scalar, value_scalar)
|
||||
|
||||
DEFINE_ROT(array_base, array_base, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, array_base, math_expression, math_expression)
|
||||
DEFINE_ROT(array_base, math_expression, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
DEFINE_ROT(array_base, array_base, expression_tree, expression_tree)
|
||||
DEFINE_ROT(expression_tree, array_base, expression_tree, expression_tree)
|
||||
DEFINE_ROT(array_base, expression_tree, expression_tree, expression_tree)
|
||||
DEFINE_ROT(expression_tree, expression_tree, expression_tree, expression_tree)
|
||||
|
||||
|
||||
|
||||
@@ -621,11 +621,11 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
/*--- Math Operators----*/
|
||||
//---------------------------------------
|
||||
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
|
||||
math_expression OPNAME (array_base const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
expression_tree OPNAME (expression_tree const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
|
||||
@@ -669,14 +669,14 @@ inline operation_type casted(numeric_type dtype)
|
||||
}
|
||||
}
|
||||
|
||||
math_expression cast(array_base const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
expression_tree cast(array_base const & x, numeric_type dtype)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
math_expression cast(math_expression const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
expression_tree cast(expression_tree const & x, numeric_type dtype)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
|
||||
isaac::expression_tree eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return expression_tree(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
|
||||
|
||||
array diag(array_base & x, int offset)
|
||||
{
|
||||
@@ -688,8 +688,8 @@ array diag(array_base & x, int offset)
|
||||
}
|
||||
|
||||
|
||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
|
||||
isaac::expression_tree zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return expression_tree(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
|
||||
|
||||
inline shape_t flip(shape_t const & shape)
|
||||
{
|
||||
@@ -702,77 +702,77 @@ inline shape_t flip(shape_t const & shape)
|
||||
//inline size4 prod(size4 const & shape1, size4 const & shape2)
|
||||
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
|
||||
|
||||
math_expression trans(array_base const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
expression_tree trans(array_base const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
\
|
||||
math_expression trans(math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
expression_tree trans(expression_tree const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
|
||||
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
||||
expression_tree repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
}
|
||||
|
||||
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
|
||||
expression_tree repmat(expression_tree const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
}
|
||||
|
||||
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
|
||||
math_expression row(TYPEA const & x, TYPEB const & i)\
|
||||
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
||||
expression_tree row(TYPEA const & x, TYPEB const & i)\
|
||||
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
||||
|
||||
DEFINE_ACCESS_ROW(array_base, value_scalar)
|
||||
DEFINE_ACCESS_ROW(array_base, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(array_base, math_expression)
|
||||
DEFINE_ACCESS_ROW(array_base, expression_tree)
|
||||
|
||||
DEFINE_ACCESS_ROW(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_ROW(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(math_expression, math_expression)
|
||||
DEFINE_ACCESS_ROW(expression_tree, value_scalar)
|
||||
DEFINE_ACCESS_ROW(expression_tree, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(expression_tree, expression_tree)
|
||||
|
||||
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
|
||||
math_expression col(TYPEA const & x, TYPEB const & i)\
|
||||
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
||||
expression_tree col(TYPEA const & x, TYPEB const & i)\
|
||||
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
||||
|
||||
DEFINE_ACCESS_COL(array_base, value_scalar)
|
||||
DEFINE_ACCESS_COL(array_base, for_idx_t)
|
||||
DEFINE_ACCESS_COL(array_base, math_expression)
|
||||
DEFINE_ACCESS_COL(array_base, expression_tree)
|
||||
|
||||
DEFINE_ACCESS_COL(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_COL(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_COL(math_expression, math_expression)
|
||||
DEFINE_ACCESS_COL(expression_tree, value_scalar)
|
||||
DEFINE_ACCESS_COL(expression_tree, for_idx_t)
|
||||
DEFINE_ACCESS_COL(expression_tree, expression_tree)
|
||||
|
||||
////---------------------------------------
|
||||
|
||||
///*--- Reductions ---*/
|
||||
////---------------------------------------
|
||||
#define DEFINE_REDUCTION(OP, OPNAME)\
|
||||
math_expression OPNAME(array_base const & x, int_t axis)\
|
||||
expression_tree OPNAME(array_base const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.dim())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
else if(axis==-1)\
|
||||
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
else if(axis==0)\
|
||||
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
else\
|
||||
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
}\
|
||||
\
|
||||
math_expression OPNAME(math_expression const & x, int_t axis)\
|
||||
expression_tree OPNAME(expression_tree const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.dim())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
if(axis==-1)\
|
||||
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
else if(axis==0)\
|
||||
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
else\
|
||||
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
}
|
||||
|
||||
DEFINE_REDUCTION(ADD_TYPE, sum)
|
||||
@@ -786,51 +786,51 @@ DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
|
||||
namespace detail
|
||||
{
|
||||
|
||||
math_expression matmatprod(array_base const & A, array_base const & B)
|
||||
expression_tree matmatprod(array_base const & A, array_base const & B)
|
||||
{
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
return expression_tree(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
}
|
||||
|
||||
math_expression matmatprod(math_expression const & A, array_base const & B)
|
||||
expression_tree matmatprod(expression_tree const & A, array_base const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
if(A_trans){
|
||||
type = MATRIX_PRODUCT_TN_TYPE;
|
||||
}
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
math_expression matmatprod(array_base const & A, math_expression const & B)
|
||||
expression_tree matmatprod(array_base const & A, expression_tree const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
|
||||
bool B_trans = B_root.op.type==TRANS_TYPE;
|
||||
if(B_trans){
|
||||
type = MATRIX_PRODUCT_NT_TYPE;
|
||||
}
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
math_expression matmatprod(math_expression const & A, math_expression const & B)
|
||||
expression_tree matmatprod(expression_tree const & A, expression_tree const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
@@ -841,15 +841,15 @@ namespace detail
|
||||
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
|
||||
else type = MATRIX_PRODUCT_NN_TYPE;
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
math_expression matvecprod(array_base const & A, T const & x)
|
||||
expression_tree matvecprod(array_base const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
@@ -857,11 +857,11 @@ namespace detail
|
||||
}
|
||||
|
||||
template<class T>
|
||||
math_expression matvecprod(math_expression const & A, T const & x)
|
||||
expression_tree matvecprod(expression_tree const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
|
||||
A_root = A.tree()[A_root.lhs.node_index];
|
||||
@@ -869,7 +869,7 @@ namespace detail
|
||||
}
|
||||
if(A_trans)
|
||||
{
|
||||
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
||||
expression_tree tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
||||
//Remove trans
|
||||
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
||||
return sum(tmp, 0);
|
||||
@@ -889,17 +889,17 @@ ISAACAPI void swap(view x, view y)
|
||||
}
|
||||
|
||||
//Reshape
|
||||
math_expression reshape(array_base const & x, shape_t const & shape)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
expression_tree reshape(array_base const & x, shape_t const & shape)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
|
||||
math_expression reshape(math_expression const & x, shape_t const & shape)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
expression_tree reshape(expression_tree const & x, shape_t const & shape)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
|
||||
math_expression ravel(array_base const & x)
|
||||
expression_tree ravel(array_base const & x)
|
||||
{ return reshape(x, {x.shape().prod()}); }
|
||||
|
||||
#define DEFINE_DOT(LTYPE, RTYPE) \
|
||||
math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
expression_tree dot(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
numeric_type dtype = x.dtype();\
|
||||
driver::Context const & context = x.context();\
|
||||
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
if(x.dim()==2 && x.shape()[1]==0)\
|
||||
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
|
||||
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
|
||||
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
|
||||
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
|
||||
if(x.dim()==1 && y.dim()==1)\
|
||||
return sum(x*y);\
|
||||
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
|
||||
@@ -940,15 +940,15 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
}
|
||||
|
||||
DEFINE_DOT(array_base, array_base)
|
||||
DEFINE_DOT(math_expression, array_base)
|
||||
DEFINE_DOT(array_base, math_expression)
|
||||
DEFINE_DOT(math_expression, math_expression)
|
||||
DEFINE_DOT(expression_tree, array_base)
|
||||
DEFINE_DOT(array_base, expression_tree)
|
||||
DEFINE_DOT(expression_tree, expression_tree)
|
||||
|
||||
#undef DEFINE_DOT
|
||||
|
||||
|
||||
#define DEFINE_NORM(TYPE)\
|
||||
math_expression norm(TYPE const & x, unsigned int order)\
|
||||
expression_tree norm(TYPE const & x, unsigned int order)\
|
||||
{\
|
||||
assert(order > 0 && order < 3);\
|
||||
switch(order)\
|
||||
@@ -959,21 +959,21 @@ math_expression norm(TYPE const & x, unsigned int order)\
|
||||
}
|
||||
|
||||
DEFINE_NORM(array_base)
|
||||
DEFINE_NORM(math_expression)
|
||||
DEFINE_NORM(expression_tree)
|
||||
|
||||
#undef DEFINE_NORM
|
||||
|
||||
/*--- Fusion ----*/
|
||||
math_expression fuse(math_expression const & x, math_expression const & y)
|
||||
expression_tree fuse(expression_tree const & x, expression_tree const & y)
|
||||
{
|
||||
assert(x.context()==y.context());
|
||||
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
||||
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
||||
}
|
||||
|
||||
/*--- For loops ---*/
|
||||
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
|
||||
ISAACAPI expression_tree sfor(expression_tree const & start, expression_tree const & end, expression_tree const & inc, expression_tree const & x)
|
||||
{
|
||||
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
||||
return expression_tree(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
||||
}
|
||||
|
||||
|
||||
@@ -1160,7 +1160,7 @@ std::ostream& operator<<(std::ostream & os, array_base const & a)
|
||||
return os;
|
||||
}
|
||||
|
||||
ISAACAPI std::ostream& operator<<(std::ostream & oss, math_expression const & expression)
|
||||
ISAACAPI std::ostream& operator<<(std::ostream & oss, expression_tree const & expression)
|
||||
{
|
||||
return oss << array(expression);
|
||||
}
|
||||
|
Reference in New Issue
Block a user