Code Quality: heavy renaming and cleaning

This commit is contained in:
Philippe Tillet
2015-12-19 02:04:39 -05:00
parent b6d596d26d
commit bfa7504fc0
22 changed files with 502 additions and 518 deletions

View File

@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -188,7 +188,7 @@ array_base& array_base::operator=(execution_handler const & c)
{
if(shape_.min()==0) return *this;
assert(dtype_ == c.x().dtype());
math_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
return *this;
}
@@ -228,53 +228,53 @@ INSTANTIATE(double);
math_expression array_base::operator-()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
math_expression array_base::operator!()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
//
array_base & array_base::operator+=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator-=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator*=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator/=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
/*--- Indexing operators -----*/
//---------------------------------------
math_expression array_base::operator[](for_idx_t idx) const
{
return math_expression(*this, idx, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE), context_, dtype_, {1});
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
}
scalar array_base::operator [](int_t idx)
@@ -512,72 +512,72 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
math_expression OPNAME (array_base const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (array_base const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
\
math_expression OPNAME (array_base const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
\
math_expression OPNAME (math_expression const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
\
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (value_scalar const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
\
\
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
\
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); }
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ASSIGN_TYPE, assign, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
#define DEFINE_OUTER(LTYPE, RTYPE) \
math_expression outer(LTYPE const & x, RTYPE const & y)\
@@ -585,7 +585,7 @@ math_expression outer(LTYPE const & x, RTYPE const & y)\
assert(x.dim()<=1 && y.dim()<=1);\
if(x.dim()<1 || y.dim()<1)\
return x*y;\
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
}\
DEFINE_OUTER(array_base, array_base)
@@ -622,61 +622,61 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
//---------------------------------------
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
math_expression OPNAME (array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\
math_expression OPNAME (math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG10_TYPE,log10)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(LOG10_TYPE,log10)
DEFINE_ELEMENT_UNARY_OPERATOR(SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(TANH_TYPE, tanh)
#undef DEFINE_ELEMENT_UNARY_OPERATOR
//---------------------------------------
///*--- Misc----*/
////---------------------------------------
inline operation_node_type casted(numeric_type dtype)
inline operation_type casted(numeric_type dtype)
{
switch(dtype)
{
// case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE;
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
// case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE;
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
// case BOOL_TYPE: return CAST_BOOL_TYPE;
case CHAR_TYPE: return CAST_CHAR_TYPE;
case UCHAR_TYPE: return CAST_UCHAR_TYPE;
case SHORT_TYPE: return CAST_SHORT_TYPE;
case USHORT_TYPE: return CAST_USHORT_TYPE;
case INT_TYPE: return CAST_INT_TYPE;
case UINT_TYPE: return CAST_UINT_TYPE;
case LONG_TYPE: return CAST_LONG_TYPE;
case ULONG_TYPE: return CAST_ULONG_TYPE;
// case FLOAT_TYPE: return CAST_HALF_TYPE;
case FLOAT_TYPE: return CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype);
}
}
math_expression cast(array_base const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
math_expression cast(math_expression const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, {M, N}); }
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
array diag(array_base & x, int offset)
{
@@ -689,7 +689,7 @@ array diag(array_base & x, int offset)
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, {M, N}); }
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
inline shape_t flip(shape_t const & shape)
{
@@ -703,28 +703,28 @@ inline shape_t flip(shape_t const & shape)
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
math_expression trans(array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\
math_expression trans(math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
math_expression row(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
DEFINE_ACCESS_ROW(array_base, value_scalar)
DEFINE_ACCESS_ROW(array_base, for_idx_t)
@@ -736,7 +736,7 @@ DEFINE_ACCESS_ROW(math_expression, math_expression)
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
math_expression col(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
DEFINE_ACCESS_COL(array_base, value_scalar)
DEFINE_ACCESS_COL(array_base, for_idx_t)
@@ -756,11 +756,11 @@ math_expression OPNAME(array_base const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}\
\
math_expression OPNAME(math_expression const & x, int_t axis)\
@@ -768,18 +768,18 @@ math_expression OPNAME(math_expression const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
DEFINE_REDUCTION(ADD_TYPE, sum)
DEFINE_REDUCTION(ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
#undef DEFINE_REDUCTION
@@ -789,21 +789,21 @@ namespace detail
math_expression matmatprod(array_base const & A, array_base const & B)
{
shape_t shape{A.shape()[0], B.shape()[1]};
return math_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
}
math_expression matmatprod(math_expression const & A, array_base const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool A_trans = A_root.op.type==TRANS_TYPE;
if(A_trans){
type = OPERATOR_GEMM_TN_TYPE;
type = MATRIX_PRODUCT_TN_TYPE;
}
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
return res;
@@ -811,16 +811,16 @@ namespace detail
math_expression matmatprod(array_base const & A, math_expression const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==TRANS_TYPE;
if(B_trans){
type = OPERATOR_GEMM_NT_TYPE;
type = MATRIX_PRODUCT_NT_TYPE;
}
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs;
return res;
@@ -828,20 +828,20 @@ namespace detail
math_expression matmatprod(math_expression const & A, math_expression const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
shape_t shape{A.shape()[0], B.shape()[1]};
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
bool A_trans = A_root.op.type==TRANS_TYPE;
bool B_trans = B_root.op.type==TRANS_TYPE;
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
else type = OPERATOR_GEMM_NN_TYPE;
if(A_trans && B_trans) type = MATRIX_PRODUCT_TT_TYPE;
else if(A_trans && !B_trans) type = MATRIX_PRODUCT_TN_TYPE;
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
else type = MATRIX_PRODUCT_NN_TYPE;
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs;
@@ -862,14 +862,14 @@ namespace detail
int_t M = A.shape()[0];
int_t N = A.shape()[1];
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
while(A_root.lhs.type_family==COMPOSITE_OPERATOR_FAMILY){
bool A_trans = A_root.op.type==TRANS_TYPE;
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
A_root = A.tree()[A_root.lhs.node_index];
A_trans ^= A_root.op.type==OPERATOR_TRANS_TYPE;
A_trans ^= A_root.op.type==TRANS_TYPE;
}
if(A_trans)
{
math_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
//Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 0);
@@ -890,10 +890,10 @@ ISAACAPI void swap(view x, view y)
//Reshape
math_expression reshape(array_base const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression reshape(math_expression const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression ravel(array_base const & x)
{ return reshape(x, {x.shape().prod()}); }
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
return math_expression(invalid_node(), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_INVALID_TYPE), context, dtype, {0});\
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
if(x.dim()==1 && y.dim()==1)\
return sum(x*y);\
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
@@ -967,13 +967,13 @@ DEFINE_NORM(math_expression)
math_expression fuse(math_expression const & x, math_expression const & y)
{
assert(x.context()==y.context());
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
}
/*--- For loops ---*/
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
{
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE), x.context(), x.dtype(), x.shape());
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
}