Code Quality: heavy renaming and cleaning

This commit is contained in:
Philippe Tillet
2015-12-19 02:04:39 -05:00
parent b6d596d26d
commit bfa7504fc0
22 changed files with 502 additions and 518 deletions

View File

@@ -44,7 +44,7 @@ extern "C" {
* SWAP, SCAL, COPY, AXPY, DOT, DOTU, DOTC, ROTG, ROTMG, ROT, ROTM, iAMAX, ASUM and NRM2,
* BLAS-2 functions GEMV, SYMV, TRMV, TRSV, HEMV, SYR, SYR2, HER, HER2, GER, GERU, GERC,
* TPMV, SPMV, HPMV, TPSV, SPR, SPR2, HPR, HPR2, GBMV, TBMV, SBMV, HBMV and TBSV
* and BLAS-3 functions GEMM, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K.
* and BLAS-3 functions MATRIX_PRODUCT, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K.
*
* This librarys primary goal is to assist the end user to enqueue OpenCL
* kernels to process BLAS functions in an OpenCL-efficient manner, while
@@ -7314,7 +7314,7 @@ clblasZtbsv(
/*@}*/
/**
* @defgroup GEMM GEMM - General matrix-matrix multiplication
* @defgroup MATRIX_PRODUCT MATRIX_PRODUCT - General matrix-matrix multiplication
* @ingroup BLAS3
*/
/*@{*/
@@ -7372,7 +7372,7 @@ clblasZtbsv(
* the size of the respective buffer object;
* - the same error codes as clblasSgemm() otherwise.
*
* @ingroup GEMM
* @ingroup MATRIX_PRODUCT
*/
clblasStatus
clblasSgemm(
@@ -7453,7 +7453,7 @@ clblasSgemm(
* the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise.
*
* @ingroup GEMM
* @ingroup MATRIX_PRODUCT
*/
clblasStatus
clblasDgemm(
@@ -7527,7 +7527,7 @@ clblasDgemm(
* the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise.
*
* @ingroup GEMM
* @ingroup MATRIX_PRODUCT
*/
clblasStatus
clblasCgemm(
@@ -7603,7 +7603,7 @@ clblasCgemm(
* the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise.
*
* @ingroup GEMM
* @ingroup MATRIX_PRODUCT
*/
clblasStatus
clblasZgemm(

View File

@@ -418,7 +418,7 @@ void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha,
const cuDoubleComplex *x, int incx, const cuDoubleComplex *y,
int incy, cuDoubleComplex *AP);
/* ------------------------BLAS3 Functions ------------------------------- */
/* GEMM */
/* MATRIX_PRODUCT */
void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k,
float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C,

View File

@@ -1508,7 +1508,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle,
/* ---------------- CUBLAS BLAS3 functions ---------------- */
/* GEMM */
/* MATRIX_PRODUCT */
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
@@ -2042,7 +2042,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cubl
int ldb,
cuDoubleComplex *C,
int ldc);
/* BATCH GEMM */
/* BATCH MATRIX_PRODUCT */
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,

View File

@@ -248,9 +248,9 @@ public:
class mapped_cast : public mapped_object
{
static std::string operator_to_str(operation_node_type type);
static std::string operator_to_str(operation_type type);
public:
mapped_cast(operation_node_type type, unsigned int id);
mapped_cast(operation_type type, unsigned int id);
};
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);

View File

@@ -47,9 +47,9 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
//Lhs:
if (recurse)
{
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, fun, inspect);
if (root_node.lhs.type_family != INVALID_TYPE_FAMILY)
if (root_node.lhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, LHS_NODE_TYPE);
}
@@ -58,11 +58,11 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
fun(math_expression, root_idx, PARENT_NODE_TYPE);
//Rhs:
if (recurse && root_node.rhs.type_family!=INVALID_TYPE_FAMILY)
if (recurse && root_node.rhs.subtype!=INVALID_SUBTYPE)
{
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, fun, inspect);
if (root_node.rhs.type_family != INVALID_TYPE_FAMILY)
if (root_node.rhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, RHS_NODE_TYPE);
}
@@ -84,10 +84,10 @@ private:
class filter_elements_fun : public traversal_functor
{
public:
filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out);
filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out);
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
private:
math_expression_node_subtype subtype_;
node_type subtype_;
std::vector<lhs_rhs_element> & out_;
};
@@ -96,9 +96,9 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
size_t root,
bool inspect);
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype,
std::vector<lhs_rhs_element> filter_elements(node_type subtype,
isaac::math_expression const & math_expression);
const char * evaluate(operation_node_type type);
const char * evaluate(operation_type type);
/** @brief functor for generating the expression string from a math_expression */
class evaluate_expression_traversal: public traversal_functor

View File

@@ -23,140 +23,121 @@ namespace isaac
class array_base;
/** @brief Optimization enum for grouping operations into unary or binary operations. Just for optimization of lookups. */
enum operation_node_type_family
enum operation_type_family
{
OPERATOR_INVALID_TYPE_FAMILY = 0,
INVALID_TYPE_FAMILY = 0,
// BLAS1-type
OPERATOR_UNARY_TYPE_FAMILY,
OPERATOR_BINARY_TYPE_FAMILY,
OPERATOR_VECTOR_DOT_TYPE_FAMILY,
UNARY_TYPE_FAMILY,
BINARY_TYPE_FAMILY,
VECTOR_DOT_TYPE_FAMILY,
// BLAS2-type
OPERATOR_ROWS_DOT_TYPE_FAMILY,
OPERATOR_COLUMNS_DOT_TYPE_FAMILY,
ROWS_DOT_TYPE_FAMILY,
COLUMNS_DOT_TYPE_FAMILY,
// BLAS3-type
OPERATOR_GEMM_TYPE_FAMILY
MATRIX_PRODUCT_TYPE_FAMILY
};
/** @brief Enumeration for identifying the possible operations */
enum operation_node_type
enum operation_type
{
OPERATOR_INVALID_TYPE = 0,
INVALID_TYPE = 0,
// unary operator
OPERATOR_MINUS_TYPE,
OPERATOR_NEGATE_TYPE,
MINUS_TYPE,
NEGATE_TYPE,
// unary expression
OPERATOR_CAST_BOOL_TYPE,
OPERATOR_CAST_CHAR_TYPE,
OPERATOR_CAST_UCHAR_TYPE,
OPERATOR_CAST_SHORT_TYPE,
OPERATOR_CAST_USHORT_TYPE,
OPERATOR_CAST_INT_TYPE,
OPERATOR_CAST_UINT_TYPE,
OPERATOR_CAST_LONG_TYPE,
OPERATOR_CAST_ULONG_TYPE,
OPERATOR_CAST_HALF_TYPE,
OPERATOR_CAST_FLOAT_TYPE,
OPERATOR_CAST_DOUBLE_TYPE,
CAST_BOOL_TYPE,
CAST_CHAR_TYPE,
CAST_UCHAR_TYPE,
CAST_SHORT_TYPE,
CAST_USHORT_TYPE,
CAST_INT_TYPE,
CAST_UINT_TYPE,
CAST_LONG_TYPE,
CAST_ULONG_TYPE,
CAST_HALF_TYPE,
CAST_FLOAT_TYPE,
CAST_DOUBLE_TYPE,
OPERATOR_ABS_TYPE,
OPERATOR_ACOS_TYPE,
OPERATOR_ASIN_TYPE,
OPERATOR_ATAN_TYPE,
OPERATOR_CEIL_TYPE,
OPERATOR_COS_TYPE,
OPERATOR_COSH_TYPE,
OPERATOR_EXP_TYPE,
OPERATOR_FABS_TYPE,
OPERATOR_FLOOR_TYPE,
OPERATOR_LOG_TYPE,
OPERATOR_LOG10_TYPE,
OPERATOR_SIN_TYPE,
OPERATOR_SINH_TYPE,
OPERATOR_SQRT_TYPE,
OPERATOR_TAN_TYPE,
OPERATOR_TANH_TYPE,
OPERATOR_TRANS_TYPE,
ABS_TYPE,
ACOS_TYPE,
ASIN_TYPE,
ATAN_TYPE,
CEIL_TYPE,
COS_TYPE,
COSH_TYPE,
EXP_TYPE,
FABS_TYPE,
FLOOR_TYPE,
LOG_TYPE,
LOG10_TYPE,
SIN_TYPE,
SINH_TYPE,
SQRT_TYPE,
TAN_TYPE,
TANH_TYPE,
TRANS_TYPE,
// binary expression
OPERATOR_ASSIGN_TYPE,
OPERATOR_INPLACE_ADD_TYPE,
OPERATOR_INPLACE_SUB_TYPE,
OPERATOR_ADD_TYPE,
OPERATOR_SUB_TYPE,
OPERATOR_MULT_TYPE,
OPERATOR_DIV_TYPE,
OPERATOR_ELEMENT_ARGFMAX_TYPE,
OPERATOR_ELEMENT_ARGFMIN_TYPE,
OPERATOR_ELEMENT_ARGMAX_TYPE,
OPERATOR_ELEMENT_ARGMIN_TYPE,
OPERATOR_ELEMENT_PROD_TYPE,
OPERATOR_ELEMENT_DIV_TYPE,
OPERATOR_ELEMENT_EQ_TYPE,
OPERATOR_ELEMENT_NEQ_TYPE,
OPERATOR_ELEMENT_GREATER_TYPE,
OPERATOR_ELEMENT_GEQ_TYPE,
OPERATOR_ELEMENT_LESS_TYPE,
OPERATOR_ELEMENT_LEQ_TYPE,
OPERATOR_ELEMENT_POW_TYPE,
OPERATOR_ELEMENT_FMAX_TYPE,
OPERATOR_ELEMENT_FMIN_TYPE,
OPERATOR_ELEMENT_MAX_TYPE,
OPERATOR_ELEMENT_MIN_TYPE,
ASSIGN_TYPE,
INPLACE_ADD_TYPE,
INPLACE_SUB_TYPE,
ADD_TYPE,
SUB_TYPE,
MULT_TYPE,
DIV_TYPE,
ELEMENT_ARGFMAX_TYPE,
ELEMENT_ARGFMIN_TYPE,
ELEMENT_ARGMAX_TYPE,
ELEMENT_ARGMIN_TYPE,
ELEMENT_PROD_TYPE,
ELEMENT_DIV_TYPE,
ELEMENT_EQ_TYPE,
ELEMENT_NEQ_TYPE,
ELEMENT_GREATER_TYPE,
ELEMENT_GEQ_TYPE,
ELEMENT_LESS_TYPE,
ELEMENT_LEQ_TYPE,
ELEMENT_POW_TYPE,
ELEMENT_FMAX_TYPE,
ELEMENT_FMIN_TYPE,
ELEMENT_MAX_TYPE,
ELEMENT_MIN_TYPE,
//Products
OPERATOR_OUTER_PROD_TYPE,
OPERATOR_GEMM_NN_TYPE,
OPERATOR_GEMM_TN_TYPE,
OPERATOR_GEMM_NT_TYPE,
OPERATOR_GEMM_TT_TYPE,
OUTER_PROD_TYPE,
MATRIX_PRODUCT_NN_TYPE,
MATRIX_PRODUCT_TN_TYPE,
MATRIX_PRODUCT_NT_TYPE,
MATRIX_PRODUCT_TT_TYPE,
//Access modifiers
OPERATOR_MATRIX_DIAG_TYPE,
OPERATOR_MATRIX_ROW_TYPE,
OPERATOR_MATRIX_COLUMN_TYPE,
OPERATOR_REPEAT_TYPE,
OPERATOR_RESHAPE_TYPE,
OPERATOR_SHIFT_TYPE,
OPERATOR_VDIAG_TYPE,
OPERATOR_ACCESS_INDEX_TYPE,
MATRIX_DIAG_TYPE,
MATRIX_ROW_TYPE,
MATRIX_COLUMN_TYPE,
REPEAT_TYPE,
RESHAPE_TYPE,
SHIFT_TYPE,
VDIAG_TYPE,
ACCESS_INDEX_TYPE,
OPERATOR_PAIR_TYPE,
PAIR_TYPE,
OPERATOR_FUSE,
OPERATOR_SFOR_TYPE,
};
/** @brief Groups the type of a node in the math_expression tree. Used for faster dispatching */
enum math_expression_node_type_family
{
INVALID_TYPE_FAMILY = 0,
COMPOSITE_OPERATOR_FAMILY,
VALUE_TYPE_FAMILY,
ARRAY_TYPE_FAMILY,
PLACEHOLDER_TYPE_FAMILY
};
/** @brief Encodes the type of a node in the math_expression tree. */
enum math_expression_node_subtype
{
INVALID_SUBTYPE = 0,
VALUE_SCALAR_TYPE,
DENSE_ARRAY_TYPE,
FOR_LOOP_INDEX_TYPE
SFOR_TYPE,
};
struct op_element
{
op_element();
op_element(operation_node_type_family const & _type_family, operation_node_type const & _type);
operation_node_type_family type_family;
operation_node_type type;
op_element(operation_type_family const & _type_family, operation_type const & _type);
operation_type_family type_family;
operation_type type;
};
struct for_idx_t
@@ -172,11 +153,19 @@ struct for_idx_t
int level;
};
enum node_type
{
INVALID_SUBTYPE = 0,
COMPOSITE_OPERATOR_TYPE,
VALUE_SCALAR_TYPE,
DENSE_ARRAY_TYPE,
FOR_LOOP_INDEX_TYPE
};
struct lhs_rhs_element
{
lhs_rhs_element();
math_expression_node_type_family type_family;
math_expression_node_subtype subtype;
node_type subtype;
numeric_type dtype;
union
{

View File

@@ -7,7 +7,7 @@
namespace isaac
{
std::string to_string(math_expression_node_subtype const & f);
std::string to_string(node_type const & f);
std::string to_string(lhs_rhs_element const & e);
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
std::string to_string(isaac::math_expression const & s);

View File

@@ -19,13 +19,13 @@ ISAACAPI typename std::conditional<std::is_arithmetic<T>::value, value_scalar, T
template<typename T, typename... Args>
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_PAIR_TYPE), context, numeric_type_of(x), {1}); }
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
{
for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root];
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index;
else
return value_scalar(node.rhs.vscalar, node.rhs.dtype);

View File

@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -188,7 +188,7 @@ array_base& array_base::operator=(execution_handler const & c)
{
if(shape_.min()==0) return *this;
assert(dtype_ == c.x().dtype());
math_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
return *this;
}
@@ -228,53 +228,53 @@ INSTANTIATE(double);
math_expression array_base::operator-()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
math_expression array_base::operator!()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
//
array_base & array_base::operator+=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator-=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator*=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator/=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
/*--- Indexing operators -----*/
//---------------------------------------
math_expression array_base::operator[](for_idx_t idx) const
{
return math_expression(*this, idx, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE), context_, dtype_, {1});
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
}
scalar array_base::operator [](int_t idx)
@@ -512,72 +512,72 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
math_expression OPNAME (array_base const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (array_base const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
\
math_expression OPNAME (array_base const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
\
math_expression OPNAME (math_expression const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
\
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (value_scalar const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); }\
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
\
\
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
\
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); }
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(ASSIGN_TYPE, assign, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
#define DEFINE_OUTER(LTYPE, RTYPE) \
math_expression outer(LTYPE const & x, RTYPE const & y)\
@@ -585,7 +585,7 @@ math_expression outer(LTYPE const & x, RTYPE const & y)\
assert(x.dim()<=1 && y.dim()<=1);\
if(x.dim()<1 || y.dim()<1)\
return x*y;\
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
}\
DEFINE_OUTER(array_base, array_base)
@@ -622,61 +622,61 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
//---------------------------------------
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
math_expression OPNAME (array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\
math_expression OPNAME (math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG10_TYPE,log10)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(LOG10_TYPE,log10)
DEFINE_ELEMENT_UNARY_OPERATOR(SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(TANH_TYPE, tanh)
#undef DEFINE_ELEMENT_UNARY_OPERATOR
//---------------------------------------
///*--- Misc----*/
////---------------------------------------
inline operation_node_type casted(numeric_type dtype)
inline operation_type casted(numeric_type dtype)
{
switch(dtype)
{
// case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE;
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
// case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE;
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
// case BOOL_TYPE: return CAST_BOOL_TYPE;
case CHAR_TYPE: return CAST_CHAR_TYPE;
case UCHAR_TYPE: return CAST_UCHAR_TYPE;
case SHORT_TYPE: return CAST_SHORT_TYPE;
case USHORT_TYPE: return CAST_USHORT_TYPE;
case INT_TYPE: return CAST_INT_TYPE;
case UINT_TYPE: return CAST_UINT_TYPE;
case LONG_TYPE: return CAST_LONG_TYPE;
case ULONG_TYPE: return CAST_ULONG_TYPE;
// case FLOAT_TYPE: return CAST_HALF_TYPE;
case FLOAT_TYPE: return CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype);
}
}
math_expression cast(array_base const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
math_expression cast(math_expression const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, {M, N}); }
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
array diag(array_base & x, int offset)
{
@@ -689,7 +689,7 @@ array diag(array_base & x, int offset)
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, {M, N}); }
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
inline shape_t flip(shape_t const & shape)
{
@@ -703,28 +703,28 @@ inline shape_t flip(shape_t const & shape)
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
math_expression trans(array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\
math_expression trans(math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
math_expression row(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
DEFINE_ACCESS_ROW(array_base, value_scalar)
DEFINE_ACCESS_ROW(array_base, for_idx_t)
@@ -736,7 +736,7 @@ DEFINE_ACCESS_ROW(math_expression, math_expression)
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
math_expression col(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
DEFINE_ACCESS_COL(array_base, value_scalar)
DEFINE_ACCESS_COL(array_base, for_idx_t)
@@ -756,11 +756,11 @@ math_expression OPNAME(array_base const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}\
\
math_expression OPNAME(math_expression const & x, int_t axis)\
@@ -768,18 +768,18 @@ math_expression OPNAME(math_expression const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
DEFINE_REDUCTION(ADD_TYPE, sum)
DEFINE_REDUCTION(ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
#undef DEFINE_REDUCTION
@@ -789,21 +789,21 @@ namespace detail
math_expression matmatprod(array_base const & A, array_base const & B)
{
shape_t shape{A.shape()[0], B.shape()[1]};
return math_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
}
math_expression matmatprod(math_expression const & A, array_base const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool A_trans = A_root.op.type==TRANS_TYPE;
if(A_trans){
type = OPERATOR_GEMM_TN_TYPE;
type = MATRIX_PRODUCT_TN_TYPE;
}
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
return res;
@@ -811,16 +811,16 @@ namespace detail
math_expression matmatprod(array_base const & A, math_expression const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==TRANS_TYPE;
if(B_trans){
type = OPERATOR_GEMM_NT_TYPE;
type = MATRIX_PRODUCT_NT_TYPE;
}
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs;
return res;
@@ -828,20 +828,20 @@ namespace detail
math_expression matmatprod(math_expression const & A, math_expression const & B)
{
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
operation_type type = MATRIX_PRODUCT_NN_TYPE;
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
shape_t shape{A.shape()[0], B.shape()[1]};
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
bool A_trans = A_root.op.type==TRANS_TYPE;
bool B_trans = B_root.op.type==TRANS_TYPE;
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
else type = OPERATOR_GEMM_NN_TYPE;
if(A_trans && B_trans) type = MATRIX_PRODUCT_TT_TYPE;
else if(A_trans && !B_trans) type = MATRIX_PRODUCT_TN_TYPE;
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
else type = MATRIX_PRODUCT_NN_TYPE;
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs;
@@ -862,14 +862,14 @@ namespace detail
int_t M = A.shape()[0];
int_t N = A.shape()[1];
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
while(A_root.lhs.type_family==COMPOSITE_OPERATOR_FAMILY){
bool A_trans = A_root.op.type==TRANS_TYPE;
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
A_root = A.tree()[A_root.lhs.node_index];
A_trans ^= A_root.op.type==OPERATOR_TRANS_TYPE;
A_trans ^= A_root.op.type==TRANS_TYPE;
}
if(A_trans)
{
math_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
//Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 0);
@@ -890,10 +890,10 @@ ISAACAPI void swap(view x, view y)
//Reshape
math_expression reshape(array_base const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression reshape(math_expression const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression ravel(array_base const & x)
{ return reshape(x, {x.shape().prod()}); }
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
return math_expression(invalid_node(), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_INVALID_TYPE), context, dtype, {0});\
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
if(x.dim()==1 && y.dim()==1)\
return sum(x*y);\
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
@@ -967,13 +967,13 @@ DEFINE_NORM(math_expression)
math_expression fuse(math_expression const & x, math_expression const & y)
{
assert(x.context()==y.context());
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
}
/*--- For loops ---*/
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
{
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE), x.context(), x.dtype(), x.shape());
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
}

View File

@@ -96,7 +96,7 @@ mapped_object& get(math_expression::container_type const & tree, size_t root, ma
{
for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root];
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index;
else
return *(mapping.at(std::make_pair(root, RHS_NODE_TYPE)));
@@ -136,10 +136,10 @@ math_expression::node mapped_reduce::root_node() const
bool mapped_reduce::is_index_reduction() const
{
op_element const & op = root_op();
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
return op.type==ELEMENT_ARGFMAX_TYPE
|| op.type==ELEMENT_ARGMAX_TYPE
|| op.type==ELEMENT_ARGFMIN_TYPE
|| op.type==ELEMENT_ARGMIN_TYPE;
}
op_element mapped_reduce::root_op() const
@@ -358,27 +358,27 @@ void mapped_outer::postprocess(std::string &res) const
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
{ }
std::string mapped_cast::operator_to_str(operation_node_type type)
std::string mapped_cast::operator_to_str(operation_type type)
{
switch(type)
{
case OPERATOR_CAST_BOOL_TYPE : return "bool";
case OPERATOR_CAST_CHAR_TYPE : return "char";
case OPERATOR_CAST_UCHAR_TYPE : return "uchar";
case OPERATOR_CAST_SHORT_TYPE : return "short";
case OPERATOR_CAST_USHORT_TYPE : return "ushort";
case OPERATOR_CAST_INT_TYPE : return "int";
case OPERATOR_CAST_UINT_TYPE : return "uint";
case OPERATOR_CAST_LONG_TYPE : return "long";
case OPERATOR_CAST_ULONG_TYPE : return "ulong";
case OPERATOR_CAST_HALF_TYPE : return "half";
case OPERATOR_CAST_FLOAT_TYPE : return "float";
case OPERATOR_CAST_DOUBLE_TYPE : return "double";
case CAST_BOOL_TYPE : return "bool";
case CAST_CHAR_TYPE : return "char";
case CAST_UCHAR_TYPE : return "uchar";
case CAST_SHORT_TYPE : return "short";
case CAST_USHORT_TYPE : return "ushort";
case CAST_INT_TYPE : return "int";
case CAST_UINT_TYPE : return "uint";
case CAST_LONG_TYPE : return "long";
case CAST_ULONG_TYPE : return "ulong";
case CAST_HALF_TYPE : return "half";
case CAST_FLOAT_TYPE : return "float";
case CAST_DOUBLE_TYPE : return "double";
default : return "invalid";
}
}
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
mapped_cast::mapped_cast(operation_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
{ }

View File

@@ -16,103 +16,103 @@ namespace detail
bool is_scalar_reduce_1d(math_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY;
}
bool is_vector_reduce_1d(math_expression::node const & node)
{
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
return node.op.type_family==ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY;
}
bool is_assignment(op_element const & op)
{
return op.type== OPERATOR_ASSIGN_TYPE
|| op.type== OPERATOR_INPLACE_ADD_TYPE
|| op.type== OPERATOR_INPLACE_SUB_TYPE;
return op.type== ASSIGN_TYPE
|| op.type== INPLACE_ADD_TYPE
|| op.type== INPLACE_SUB_TYPE;
}
bool is_elementwise_operator(op_element const & op)
{
return is_assignment(op)
|| op.type== OPERATOR_ADD_TYPE
|| op.type== OPERATOR_SUB_TYPE
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|| op.type== OPERATOR_MULT_TYPE
|| op.type== OPERATOR_DIV_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
|| op.type== ADD_TYPE
|| op.type== SUB_TYPE
|| op.type== ELEMENT_PROD_TYPE
|| op.type== ELEMENT_DIV_TYPE
|| op.type== MULT_TYPE
|| op.type== DIV_TYPE
|| op.type== ELEMENT_EQ_TYPE
|| op.type== ELEMENT_NEQ_TYPE
|| op.type== ELEMENT_GREATER_TYPE
|| op.type== ELEMENT_LESS_TYPE
|| op.type== ELEMENT_GEQ_TYPE
|| op.type== ELEMENT_LEQ_TYPE ;
}
bool bypass(op_element const & op)
{
return op.type == OPERATOR_RESHAPE_TYPE
||op.type == OPERATOR_TRANS_TYPE;
return op.type == RESHAPE_TYPE
||op.type == TRANS_TYPE;
}
bool is_cast(op_element const & op)
{
return op.type== OPERATOR_CAST_BOOL_TYPE
|| op.type== OPERATOR_CAST_CHAR_TYPE
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|| op.type== OPERATOR_CAST_SHORT_TYPE
|| op.type== OPERATOR_CAST_USHORT_TYPE
|| op.type== OPERATOR_CAST_INT_TYPE
|| op.type== OPERATOR_CAST_UINT_TYPE
|| op.type== OPERATOR_CAST_LONG_TYPE
|| op.type== OPERATOR_CAST_ULONG_TYPE
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
return op.type== CAST_BOOL_TYPE
|| op.type== CAST_CHAR_TYPE
|| op.type== CAST_UCHAR_TYPE
|| op.type== CAST_SHORT_TYPE
|| op.type== CAST_USHORT_TYPE
|| op.type== CAST_INT_TYPE
|| op.type== CAST_UINT_TYPE
|| op.type== CAST_LONG_TYPE
|| op.type== CAST_ULONG_TYPE
|| op.type== CAST_FLOAT_TYPE
|| op.type== CAST_DOUBLE_TYPE
;
}
bool is_node_leaf(op_element const & op)
{
return op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_ACCESS_INDEX_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
return op.type==MATRIX_DIAG_TYPE
|| op.type==VDIAG_TYPE
|| op.type==REPEAT_TYPE
|| op.type==MATRIX_ROW_TYPE
|| op.type==MATRIX_COLUMN_TYPE
|| op.type==ACCESS_INDEX_TYPE
|| op.type==OUTER_PROD_TYPE
|| op.type_family==VECTOR_DOT_TYPE_FAMILY
|| op.type_family==ROWS_DOT_TYPE_FAMILY
|| op.type_family==COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==MATRIX_PRODUCT_TYPE_FAMILY
;
}
bool is_elementwise_function(op_element const & op)
{
return is_cast(op)
|| op.type== OPERATOR_ABS_TYPE
|| op.type== OPERATOR_ACOS_TYPE
|| op.type== OPERATOR_ASIN_TYPE
|| op.type== OPERATOR_ATAN_TYPE
|| op.type== OPERATOR_CEIL_TYPE
|| op.type== OPERATOR_COS_TYPE
|| op.type== OPERATOR_COSH_TYPE
|| op.type== OPERATOR_EXP_TYPE
|| op.type== OPERATOR_FABS_TYPE
|| op.type== OPERATOR_FLOOR_TYPE
|| op.type== OPERATOR_LOG_TYPE
|| op.type== OPERATOR_LOG10_TYPE
|| op.type== OPERATOR_SIN_TYPE
|| op.type== OPERATOR_SINH_TYPE
|| op.type== OPERATOR_SQRT_TYPE
|| op.type== OPERATOR_TAN_TYPE
|| op.type== OPERATOR_TANH_TYPE
|| op.type== ABS_TYPE
|| op.type== ACOS_TYPE
|| op.type== ASIN_TYPE
|| op.type== ATAN_TYPE
|| op.type== CEIL_TYPE
|| op.type== COS_TYPE
|| op.type== COSH_TYPE
|| op.type== EXP_TYPE
|| op.type== FABS_TYPE
|| op.type== FLOOR_TYPE
|| op.type== LOG_TYPE
|| op.type== LOG10_TYPE
|| op.type== SIN_TYPE
|| op.type== SINH_TYPE
|| op.type== SQRT_TYPE
|| op.type== TAN_TYPE
|| op.type== TANH_TYPE
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
|| op.type== OPERATOR_ELEMENT_MIN_TYPE;
|| op.type== ELEMENT_POW_TYPE
|| op.type== ELEMENT_FMAX_TYPE
|| op.type== ELEMENT_FMIN_TYPE
|| op.type== ELEMENT_MAX_TYPE
|| op.type== ELEMENT_MIN_TYPE;
}
@@ -139,7 +139,7 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
}
//
filter_elements_fun::filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out) :
filter_elements_fun::filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out) :
subtype_(subtype), out_(out)
{ }
@@ -153,84 +153,84 @@ void filter_elements_fun::operator()(isaac::math_expression const & math_express
}
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype, isaac::math_expression const & math_expression)
std::vector<lhs_rhs_element> filter_elements(node_type subtype, isaac::math_expression const & math_expression)
{
std::vector<lhs_rhs_element> res;
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
return res;
}
/** @brief generate a string from an operation_node_type */
const char * evaluate(operation_node_type type)
/** @brief generate a string from an operation_type */
const char * evaluate(operation_type type)
{
// unary expression
switch (type)
{
//Function
case OPERATOR_ABS_TYPE : return "abs";
case OPERATOR_ACOS_TYPE : return "acos";
case OPERATOR_ASIN_TYPE : return "asin";
case OPERATOR_ATAN_TYPE : return "atan";
case OPERATOR_CEIL_TYPE : return "ceil";
case OPERATOR_COS_TYPE : return "cos";
case OPERATOR_COSH_TYPE : return "cosh";
case OPERATOR_EXP_TYPE : return "exp";
case OPERATOR_FABS_TYPE : return "fabs";
case OPERATOR_FLOOR_TYPE : return "floor";
case OPERATOR_LOG_TYPE : return "log";
case OPERATOR_LOG10_TYPE : return "log10";
case OPERATOR_SIN_TYPE : return "sin";
case OPERATOR_SINH_TYPE : return "sinh";
case OPERATOR_SQRT_TYPE : return "sqrt";
case OPERATOR_TAN_TYPE : return "tan";
case OPERATOR_TANH_TYPE : return "tanh";
case ABS_TYPE : return "abs";
case ACOS_TYPE : return "acos";
case ASIN_TYPE : return "asin";
case ATAN_TYPE : return "atan";
case CEIL_TYPE : return "ceil";
case COS_TYPE : return "cos";
case COSH_TYPE : return "cosh";
case EXP_TYPE : return "exp";
case FABS_TYPE : return "fabs";
case FLOOR_TYPE : return "floor";
case LOG_TYPE : return "log";
case LOG10_TYPE : return "log10";
case SIN_TYPE : return "sin";
case SINH_TYPE : return "sinh";
case SQRT_TYPE : return "sqrt";
case TAN_TYPE : return "tan";
case TANH_TYPE : return "tanh";
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
case OPERATOR_ELEMENT_ARGMIN_TYPE : return "argmin";
case OPERATOR_ELEMENT_POW_TYPE : return "pow";
case ELEMENT_ARGFMAX_TYPE : return "argfmax";
case ELEMENT_ARGMAX_TYPE : return "argmax";
case ELEMENT_ARGFMIN_TYPE : return "argfmin";
case ELEMENT_ARGMIN_TYPE : return "argmin";
case ELEMENT_POW_TYPE : return "pow";
//Arithmetic
case OPERATOR_MINUS_TYPE : return "-";
case OPERATOR_ASSIGN_TYPE : return "=";
case OPERATOR_INPLACE_ADD_TYPE : return "+=";
case OPERATOR_INPLACE_SUB_TYPE : return "-=";
case OPERATOR_ADD_TYPE : return "+";
case OPERATOR_SUB_TYPE : return "-";
case OPERATOR_MULT_TYPE : return "*";
case OPERATOR_ELEMENT_PROD_TYPE : return "*";
case OPERATOR_DIV_TYPE : return "/";
case OPERATOR_ELEMENT_DIV_TYPE : return "/";
case MINUS_TYPE : return "-";
case ASSIGN_TYPE : return "=";
case INPLACE_ADD_TYPE : return "+=";
case INPLACE_SUB_TYPE : return "-=";
case ADD_TYPE : return "+";
case SUB_TYPE : return "-";
case MULT_TYPE : return "*";
case ELEMENT_PROD_TYPE : return "*";
case DIV_TYPE : return "/";
case ELEMENT_DIV_TYPE : return "/";
//Relational
case OPERATOR_NEGATE_TYPE: return "!";
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
case NEGATE_TYPE: return "!";
case ELEMENT_EQ_TYPE : return "==";
case ELEMENT_NEQ_TYPE : return "!=";
case ELEMENT_GREATER_TYPE : return ">";
case ELEMENT_GEQ_TYPE : return ">=";
case ELEMENT_LESS_TYPE : return "<";
case ELEMENT_LEQ_TYPE : return "<=";
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
case OPERATOR_ELEMENT_MAX_TYPE : return "max";
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
case ELEMENT_FMAX_TYPE : return "fmax";
case ELEMENT_FMIN_TYPE : return "fmin";
case ELEMENT_MAX_TYPE : return "max";
case ELEMENT_MIN_TYPE : return "min";
//Binary
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
case OPERATOR_VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row";
case OPERATOR_MATRIX_COLUMN_TYPE : return "col";
case OPERATOR_PAIR_TYPE: return "pair";
case OPERATOR_ACCESS_INDEX_TYPE: return "access";
case MATRIX_PRODUCT_NN_TYPE : return "prodNN";
case MATRIX_PRODUCT_TN_TYPE : return "prodTN";
case MATRIX_PRODUCT_NT_TYPE : return "prodNT";
case MATRIX_PRODUCT_TT_TYPE : return "prodTT";
case VDIAG_TYPE : return "vdiag";
case MATRIX_DIAG_TYPE : return "mdiag";
case MATRIX_ROW_TYPE : return "row";
case MATRIX_COLUMN_TYPE : return "col";
case PAIR_TYPE: return "pair";
case ACCESS_INDEX_TYPE: return "access";
//FOR
case OPERATOR_SFOR_TYPE: return "sfor";
case SFOR_TYPE: return "sfor";
default : throw operation_not_supported_exception("Unsupported operator");
}
@@ -245,7 +245,7 @@ void evaluate_expression_traversal::call_before_expansion(isaac::math_expression
math_expression::node const & root_node = math_expression.tree()[root_idx];
if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if (( (root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY&&root_node.op.type!=OPERATOR_ADD_TYPE) || detail::is_elementwise_function(root_node.op))
else if (( (root_node.op.type_family==UNARY_TYPE_FAMILY&&root_node.op.type!=ADD_TYPE) || detail::is_elementwise_function(root_node.op))
&& !detail::is_node_leaf(root_node.op))
str_+=evaluate(root_node.op.type);
if(root_node.op.type!=OPERATOR_FUSE)
@@ -268,7 +268,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
{
if (detail::is_node_leaf(root_node.op))
str_ += mapping_.at(key)->evaluate(accessors_);
else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
else if(root_node.op.type_family!=UNARY_TYPE_FAMILY)
{
if (detail::is_elementwise_operator(root_node.op))
str_ += evaluate(root_node.op.type);
@@ -280,7 +280,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
{
if (leaf==LHS_NODE_TYPE)
{
if (root_node.lhs.type_family!=COMPOSITE_OPERATOR_FAMILY)
if (root_node.lhs.subtype!=COMPOSITE_OPERATOR_TYPE)
{
if (root_node.lhs.subtype==FOR_LOOP_INDEX_TYPE)
str_ += "sforidx" + tools::to_string(root_node.lhs.for_idx.level);
@@ -291,7 +291,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
if (leaf==RHS_NODE_TYPE)
{
if (root_node.rhs.type_family!=COMPOSITE_OPERATOR_FAMILY)
if (root_node.rhs.subtype!=COMPOSITE_OPERATOR_TYPE)
{
if (root_node.rhs.subtype==FOR_LOOP_INDEX_TYPE)
str_ += "sforidx" + tools::to_string(root_node.rhs.for_idx.level);
@@ -312,14 +312,14 @@ std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & acc
if (leaf==RHS_NODE_TYPE)
{
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
else
traversal_functor(math_expression, root_idx, leaf);
}
else if (leaf==LHS_NODE_TYPE)
{
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
else
traversal_functor(math_expression, root_idx, leaf);
@@ -369,14 +369,14 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::strin
if (leaf==RHS_NODE_TYPE)
{
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
else
traversal_functor(math_expression, root_idx, leaf);
}
else if (leaf==LHS_NODE_TYPE)
{
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
else
traversal_functor(math_expression, root_idx, leaf);
@@ -440,9 +440,9 @@ void math_expression_representation_functor::append(char*& p, const char * str)
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
append(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
append(root_node.rhs, false);
else if (leaf_t==PARENT_NODE_TYPE)
append_id(ptr_,root_node.op.type);

View File

@@ -39,11 +39,11 @@ bool base::requires_fallback(math_expression const & expression)
int_t base::vector_size(math_expression::node const & node)
{
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
if (node.op.type==MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE)
else if (node.op.type==MATRIX_ROW_TYPE)
return node.lhs.array->shape()[1];
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
else if (node.op.type==MATRIX_COLUMN_TYPE)
return node.lhs.array->shape()[0];
else
return node.lhs.array->shape().max();
@@ -52,12 +52,12 @@ int_t base::vector_size(math_expression::node const & node)
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
{
if (node.op.type==OPERATOR_VDIAG_TYPE)
if (node.op.type==VDIAG_TYPE)
{
int_t size = node.lhs.array->shape()[0];
return std::make_pair(size,size);
}
else if(node.op.type==OPERATOR_REPEAT_TYPE)
else if(node.op.type==REPEAT_TYPE)
{
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);

View File

@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
math_expression::container_type const & tree = expressions.tree();
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==OPERATOR_SFOR_TYPE;}, expressions, expressions.root(), true);
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
for(unsigned int i = 0 ; i < sfors.size() ; ++i)
{

View File

@@ -81,11 +81,11 @@ public:
void set_arguments(lhs_rhs_element const & lhs_rhs, bool is_assigned) const
{
switch(lhs_rhs.type_family)
switch(lhs_rhs.subtype)
{
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array, is_assigned);
case PLACEHOLDER_TYPE_FAMILY: return;
case VALUE_SCALAR_TYPE: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case DENSE_ARRAY_TYPE: return set_arguments(lhs_rhs.array, is_assigned);
case FOR_LOOP_INDEX_TYPE: return;
default: throw std::runtime_error("Unrecognized type family");
}
}
@@ -93,9 +93,9 @@ public:
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.rhs, false);
}

View File

@@ -44,11 +44,11 @@ class map_functor : public traversal_functor
std::shared_ptr<mapped_object> create(lhs_rhs_element const & lhs_rhs, bool is_assigned = false) const
{
switch(lhs_rhs.type_family)
switch(lhs_rhs.subtype)
{
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array, is_assigned);
case PLACEHOLDER_TYPE_FAMILY: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level));
case VALUE_SCALAR_TYPE: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
case DENSE_ARRAY_TYPE: return create(lhs_rhs.array, is_assigned);
case FOR_LOOP_INDEX_TYPE: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level));
default: throw "";
}
}
@@ -65,31 +65,31 @@ public:
mapping_type::key_type key(root_idx, leaf_t);
math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
mapping_.insert(mapping_type::value_type(key, create(root_node.rhs)));
else if ( leaf_t== PARENT_NODE_TYPE)
{
if (root_node.op.type==OPERATOR_VDIAG_TYPE)
if (root_node.op.type==VDIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
else if (root_node.op.type==MATRIX_DIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_ROW_TYPE)
else if (root_node.op.type==MATRIX_ROW_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
else if (root_node.op.type==MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
else if(root_node.op.type==OPERATOR_ACCESS_INDEX_TYPE)
else if(root_node.op.type==ACCESS_INDEX_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
else if (detail::is_scalar_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
else if (detail::is_vector_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
else if (root_node.op.type_family == MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
else if (root_node.op.type == REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
else if (root_node.op.type == OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));

View File

@@ -25,10 +25,10 @@ inline void compute_index_reduce_1d(kernel_generation_stream & os, std::string a
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
os << acc_value << "=";
if (op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE) os << "fmax";
if (op.type==OPERATOR_ELEMENT_ARGMAX_TYPE) os << "max";
if (op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE) os << "fmin";
if (op.type==OPERATOR_ELEMENT_ARGMIN_TYPE) os << "min";
if (op.type==ELEMENT_ARGFMAX_TYPE) os << "fmax";
if (op.type==ELEMENT_ARGMAX_TYPE) os << "max";
if (op.type==ELEMENT_ARGFMIN_TYPE) os << "fmin";
if (op.type==ELEMENT_ARGMIN_TYPE) os << "min";
os << "(" << acc_value << "," << cur_value << ");"<< std::endl;
}
@@ -39,17 +39,17 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
switch (op.type)
{
case OPERATOR_ADD_TYPE : return "0";
case OPERATOR_MULT_TYPE : return "1";
case OPERATOR_DIV_TYPE : return "1";
case OPERATOR_ELEMENT_FMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_MAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_ARGMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_FMIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return INF;
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
case ADD_TYPE : return "0";
case MULT_TYPE : return "1";
case DIV_TYPE : return "1";
case ELEMENT_FMAX_TYPE : return N_INF;
case ELEMENT_ARGFMAX_TYPE : return N_INF;
case ELEMENT_MAX_TYPE : return N_INF;
case ELEMENT_ARGMAX_TYPE : return N_INF;
case ELEMENT_FMIN_TYPE : return INF;
case ELEMENT_ARGFMIN_TYPE : return INF;
case ELEMENT_MIN_TYPE : return INF;
case ELEMENT_ARGMIN_TYPE : return INF;
default: throw std::runtime_error("Unsupported reduce_1d operator : no neutral element known");
}
@@ -57,18 +57,18 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
inline bool is_reduce_1d(math_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==ROWS_DOT_TYPE_FAMILY;
}
inline bool is_index_reduction(op_element const & op)
{
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
return op.type==ELEMENT_ARGFMAX_TYPE
|| op.type==ELEMENT_ARGMAX_TYPE
|| op.type==ELEMENT_ARGFMIN_TYPE
|| op.type==ELEMENT_ARGMIN_TYPE;
}

View File

@@ -32,27 +32,27 @@ namespace isaac
bool result = false;
switch(op.type_family)
{
case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY:
case UNARY_TYPE_FAMILY:
case BINARY_TYPE_FAMILY:
result |= is_mmprod(expression)
|| (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|| (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
break;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
case VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression)
|| expression==REDUCE_1D;
break;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
case ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==REDUCE_1D;
break;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
case COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==REDUCE_1D;
break;
case OPERATOR_GEMM_TYPE_FAMILY:
case MATRIX_PRODUCT_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression)
|| expression==REDUCE_1D;
@@ -74,30 +74,30 @@ namespace isaac
{
switch(op.type_family)
{
case OPERATOR_UNARY_TYPE_FAMILY:
case UNARY_TYPE_FAMILY:
if(is_mmprod(left))
return ELEMENTWISE_2D;
return left;
case OPERATOR_BINARY_TYPE_FAMILY:
case BINARY_TYPE_FAMILY:
if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OPERATOR_OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
else if(right == INVALID_EXPRESSION_TYPE) return left;
else if(left == INVALID_EXPRESSION_TYPE) return right;
throw;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
case VECTOR_DOT_TYPE_FAMILY:
return REDUCE_1D;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
case ROWS_DOT_TYPE_FAMILY:
return REDUCE_2D_ROWS;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
case COLUMNS_DOT_TYPE_FAMILY:
return REDUCE_2D_COLS;
case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return MATRIX_PRODUCT_NN;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return MATRIX_PRODUCT_TN;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return MATRIX_PRODUCT_NT;
case MATRIX_PRODUCT_TYPE_FAMILY:
if(op.type==MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN;
else if(op.type==MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN;
else if(op.type==MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT;
else return MATRIX_PRODUCT_TT;
default:
throw;
@@ -115,11 +115,11 @@ namespace isaac
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
//Left
expression_type type_left = INVALID_EXPRESSION_TYPE;
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
if (node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
parse(array, node.lhs.node_index, breakpoints, type_left, false);
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
type_left = ELEMENTWISE_1D;
else
type_left = ELEMENTWISE_2D;
@@ -127,11 +127,11 @@ namespace isaac
//Right
expression_type type_right = INVALID_EXPRESSION_TYPE;
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
if (node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
parse(array, node.rhs.node_index, breakpoints, type_right, false);
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
type_right = ELEMENTWISE_1D;
else
type_right = ELEMENTWISE_2D;
@@ -160,7 +160,7 @@ namespace isaac
std::vector<std::shared_ptr<array> > temporaries_;
expression_type final_type;
//GEMM
//MATRIX_PRODUCT
if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){
final_type = args.type;
}
@@ -208,10 +208,10 @@ namespace isaac
}
temporaries_.push_back(tmp);
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
tree[rootidx].op.type = ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *it->second;
tree[rootidx].rhs.type_family = it->second->type_family;
tree[rootidx].rhs.subtype = it->second->subtype;
//Execute
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));

View File

@@ -10,22 +10,19 @@ namespace isaac
void fill(lhs_rhs_element &x, invalid_node)
{
x.type_family = INVALID_TYPE_FAMILY;
x.subtype = INVALID_SUBTYPE;
x.dtype = INVALID_NUMERIC_TYPE;
}
void fill(lhs_rhs_element & x, std::size_t node_index)
{
x.type_family = COMPOSITE_OPERATOR_FAMILY;
x.subtype = INVALID_SUBTYPE;
x.subtype = COMPOSITE_OPERATOR_TYPE;
x.dtype = INVALID_NUMERIC_TYPE;
x.node_index = node_index;
}
void fill(lhs_rhs_element & x, for_idx_t index)
{
x.type_family = PLACEHOLDER_TYPE_FAMILY;
x.subtype = FOR_LOOP_INDEX_TYPE;
x.dtype = INVALID_NUMERIC_TYPE;
x.for_idx = index;
@@ -33,7 +30,6 @@ void fill(lhs_rhs_element & x, for_idx_t index)
void fill(lhs_rhs_element & x, array_base const & a)
{
x.type_family = ARRAY_TYPE_FAMILY;
x.subtype = DENSE_ARRAY_TYPE;
x.dtype = a.dtype();
x.array = (array_base*)&a;
@@ -41,7 +37,6 @@ void fill(lhs_rhs_element & x, array_base const & a)
void fill(lhs_rhs_element & x, value_scalar const & v)
{
x.type_family = VALUE_TYPE_FAMILY;
x.dtype = v.dtype();
x.subtype = VALUE_SCALAR_TYPE;
x.vscalar = v.values();
@@ -51,7 +46,7 @@ lhs_rhs_element::lhs_rhs_element(){}
//
op_element::op_element() {}
op_element::op_element(operation_node_type_family const & _type_family, operation_node_type const & _type) : type_family(_type_family), type(_type){}
op_element::op_element(operation_type_family const & _type_family, operation_type const & _type) : type_family(_type_family), type(_type){}
//
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
@@ -122,8 +117,8 @@ math_expression::math_expression(math_expression const & lhs, math_expression co
tree_[root_].op = op;
fill(tree_[root_].rhs, lsize + rhs.root_);
for(container_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it){
if(it->lhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->lhs.node_index+=lsize;
if(it->rhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->rhs.node_index+=lsize;
if(it->lhs.subtype==COMPOSITE_OPERATOR_TYPE) it->lhs.node_index+=lsize;
if(it->rhs.subtype==COMPOSITE_OPERATOR_TYPE) it->rhs.node_index+=lsize;
}
root_ = tree_.size() - 1;
}
@@ -181,17 +176,17 @@ int_t math_expression::dim() const
//}
math_expression math_expression::operator-()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), *context_, dtype_, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
math_expression math_expression::operator!()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), *context_, INT_TYPE, shape_); }
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
//
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
{
math_expression::node const * current = &init;
while (current->lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
while (current->lhs.subtype==COMPOSITE_OPERATOR_TYPE)
current = &array[current->lhs.node_index];
return *current;
}
@@ -200,8 +195,8 @@ math_expression::node const & lhs_most(math_expression::container_type const & a
{ return lhs_most(array, array[root]); }
//
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.dtype()); }
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }

View File

@@ -10,7 +10,7 @@ namespace isaac
#define ISAAC_MAP_TO_STRING(NAME) case NAME: return #NAME
inline std::string to_string(math_expression_node_subtype const & f)
inline std::string to_string(node_type const & f)
{
switch(f)
{
@@ -23,7 +23,7 @@ inline std::string to_string(math_expression_node_subtype const & f)
inline std::string to_string(lhs_rhs_element const & e)
{
if(e.type_family==COMPOSITE_OPERATOR_FAMILY)
if(e.subtype==COMPOSITE_OPERATOR_TYPE)
{
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
}
@@ -53,10 +53,10 @@ namespace detail
os << "Node " << node_index << ": " << current_node << std::endl;
if (current_node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
if (current_node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
print_node(os, s, current_node.lhs.node_index, indent+1);
if (current_node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
if (current_node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
print_node(os, s, current_node.rhs.node_index, indent+1);
}
}

View File

@@ -12,33 +12,33 @@ namespace preset
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
{
//Matrix-Matrix product node
if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
if(tree[rootidx].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
{
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE) a.A = &tree[rootidx].lhs;
if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE) a.B = &tree[rootidx].rhs;
switch(tree[rootidx].op.type)
{
case OPERATOR_GEMM_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
case OPERATOR_GEMM_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
case OPERATOR_GEMM_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
case OPERATOR_GEMM_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
case MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
case MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
case MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
case MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
default: break;
}
}
//Scalar multiplication node
if(tree[rootidx].op.type==OPERATOR_MULT_TYPE)
if(tree[rootidx].op.type==MULT_TYPE)
{
//alpha*PROD
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE
&& tree[tree[rootidx].rhs.node_index].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
{
a.alpha = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
handle_node(tree, tree[rootidx].rhs.node_index, a);
}
//beta*C
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
{
a.beta = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
a.C = &tree[rootidx].rhs;
@@ -55,26 +55,26 @@ matrix_product::args matrix_product::check(math_expression::container_type const
return result;
result.alpha = value_scalar(1, dtype);
result.beta = value_scalar(0, dtype);
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
{
rootidx = tree[rootidx].rhs.node_index;
bool is_add = tree[rootidx].op.type==OPERATOR_ADD_TYPE;
bool is_sub = tree[rootidx].op.type==OPERATOR_SUB_TYPE;
bool is_add = tree[rootidx].op.type==ADD_TYPE;
bool is_sub = tree[rootidx].op.type==SUB_TYPE;
//Form X +- Y"
if(is_add || is_sub)
{
if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if(tree[rootidx].lhs.subtype==COMPOSITE_OPERATOR_TYPE)
handle_node(tree, tree[rootidx].lhs.node_index, result);
else if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY)
else if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE)
{
result.C = &tree[rootidx].lhs;
result.beta = value_scalar(1, dtype);
result.alpha = value_scalar(is_add?1:-1, dtype);
}
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
handle_node(tree, tree[rootidx].rhs.node_index, result);
else if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
else if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
{
result.C = &tree[rootidx].rhs;
result.alpha = value_scalar(1, dtype);

View File

@@ -169,7 +169,7 @@ extern "C"
//BLAS3
//*****************
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
#define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
size_t M, size_t N, size_t K,\
TYPE_CL alpha, const cl_mem cmA, size_t offA, size_t lda,\
@@ -208,8 +208,8 @@ extern "C"
return clblasSuccess;\
}
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double)
MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
#undef DOT

View File

@@ -194,7 +194,7 @@ extern "C"
//BLAS3
//*****************
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \
#define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \
void cublas ## TYPE_CHAR ## gemm (char transa, char transb, int m, int n, int k,\
TYPE_CU alpha, const TYPE_CU *A, int lda,\
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
@@ -235,6 +235,6 @@ extern "C"
return CUBLAS_STATUS_SUCCESS;\
}
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double)
MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
}