Code Quality: heavy renaming and cleaning

This commit is contained in:
Philippe Tillet
2015-12-19 02:04:39 -05:00
parent b6d596d26d
commit bfa7504fc0
22 changed files with 502 additions and 518 deletions

View File

@@ -44,7 +44,7 @@ extern "C" {
* SWAP, SCAL, COPY, AXPY, DOT, DOTU, DOTC, ROTG, ROTMG, ROT, ROTM, iAMAX, ASUM and NRM2, * SWAP, SCAL, COPY, AXPY, DOT, DOTU, DOTC, ROTG, ROTMG, ROT, ROTM, iAMAX, ASUM and NRM2,
* BLAS-2 functions GEMV, SYMV, TRMV, TRSV, HEMV, SYR, SYR2, HER, HER2, GER, GERU, GERC, * BLAS-2 functions GEMV, SYMV, TRMV, TRSV, HEMV, SYR, SYR2, HER, HER2, GER, GERU, GERC,
* TPMV, SPMV, HPMV, TPSV, SPR, SPR2, HPR, HPR2, GBMV, TBMV, SBMV, HBMV and TBSV * TPMV, SPMV, HPMV, TPSV, SPR, SPR2, HPR, HPR2, GBMV, TBMV, SBMV, HBMV and TBSV
* and BLAS-3 functions GEMM, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K. * and BLAS-3 functions MATRIX_PRODUCT, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K.
* *
* This librarys primary goal is to assist the end user to enqueue OpenCL * This librarys primary goal is to assist the end user to enqueue OpenCL
* kernels to process BLAS functions in an OpenCL-efficient manner, while * kernels to process BLAS functions in an OpenCL-efficient manner, while
@@ -7314,7 +7314,7 @@ clblasZtbsv(
/*@}*/ /*@}*/
/** /**
* @defgroup GEMM GEMM - General matrix-matrix multiplication * @defgroup MATRIX_PRODUCT MATRIX_PRODUCT - General matrix-matrix multiplication
* @ingroup BLAS3 * @ingroup BLAS3
*/ */
/*@{*/ /*@{*/
@@ -7372,7 +7372,7 @@ clblasZtbsv(
* the size of the respective buffer object; * the size of the respective buffer object;
* - the same error codes as clblasSgemm() otherwise. * - the same error codes as clblasSgemm() otherwise.
* *
* @ingroup GEMM * @ingroup MATRIX_PRODUCT
*/ */
clblasStatus clblasStatus
clblasSgemm( clblasSgemm(
@@ -7453,7 +7453,7 @@ clblasSgemm(
* the size of the respective buffer object; * the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise. * - the same error codes as the clblasSgemm() function otherwise.
* *
* @ingroup GEMM * @ingroup MATRIX_PRODUCT
*/ */
clblasStatus clblasStatus
clblasDgemm( clblasDgemm(
@@ -7527,7 +7527,7 @@ clblasDgemm(
* the size of the respective buffer object; * the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise. * - the same error codes as the clblasSgemm() function otherwise.
* *
* @ingroup GEMM * @ingroup MATRIX_PRODUCT
*/ */
clblasStatus clblasStatus
clblasCgemm( clblasCgemm(
@@ -7603,7 +7603,7 @@ clblasCgemm(
* the size of the respective buffer object; * the size of the respective buffer object;
* - the same error codes as the clblasSgemm() function otherwise. * - the same error codes as the clblasSgemm() function otherwise.
* *
* @ingroup GEMM * @ingroup MATRIX_PRODUCT
*/ */
clblasStatus clblasStatus
clblasZgemm( clblasZgemm(

View File

@@ -418,7 +418,7 @@ void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha,
const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y,
int incy, cuDoubleComplex *AP); int incy, cuDoubleComplex *AP);
/* ------------------------BLAS3 Functions ------------------------------- */ /* ------------------------BLAS3 Functions ------------------------------- */
/* GEMM */ /* MATRIX_PRODUCT */
void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k, void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k,
float alpha, const float *A, int lda, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, const float *B, int ldb, float beta, float *C,

View File

@@ -1508,7 +1508,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle,
/* ---------------- CUBLAS BLAS3 functions ---------------- */ /* ---------------- CUBLAS BLAS3 functions ---------------- */
/* GEMM */ /* MATRIX_PRODUCT */
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle, CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transa,
cublasOperation_t transb, cublasOperation_t transb,
@@ -2042,7 +2042,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cubl
int ldb, int ldb,
cuDoubleComplex *C, cuDoubleComplex *C,
int ldc); int ldc);
/* BATCH GEMM */ /* BATCH MATRIX_PRODUCT */
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle, CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transa,
cublasOperation_t transb, cublasOperation_t transb,

View File

@@ -248,9 +248,9 @@ public:
class mapped_cast : public mapped_object class mapped_cast : public mapped_object
{ {
static std::string operator_to_str(operation_node_type type); static std::string operator_to_str(operation_type type);
public: public:
mapped_cast(operation_node_type type, unsigned int id); mapped_cast(operation_type type, unsigned int id);
}; };
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t); extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);

View File

@@ -47,9 +47,9 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
//Lhs: //Lhs:
if (recurse) if (recurse)
{ {
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, fun, inspect); traverse(math_expression, root_node.lhs.node_index, fun, inspect);
if (root_node.lhs.type_family != INVALID_TYPE_FAMILY) if (root_node.lhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, LHS_NODE_TYPE); fun(math_expression, root_idx, LHS_NODE_TYPE);
} }
@@ -58,11 +58,11 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
fun(math_expression, root_idx, PARENT_NODE_TYPE); fun(math_expression, root_idx, PARENT_NODE_TYPE);
//Rhs: //Rhs:
if (recurse && root_node.rhs.type_family!=INVALID_TYPE_FAMILY) if (recurse && root_node.rhs.subtype!=INVALID_SUBTYPE)
{ {
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, fun, inspect); traverse(math_expression, root_node.rhs.node_index, fun, inspect);
if (root_node.rhs.type_family != INVALID_TYPE_FAMILY) if (root_node.rhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, RHS_NODE_TYPE); fun(math_expression, root_idx, RHS_NODE_TYPE);
} }
@@ -84,10 +84,10 @@ private:
class filter_elements_fun : public traversal_functor class filter_elements_fun : public traversal_functor
{ {
public: public:
filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out); filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out);
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const; void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
private: private:
math_expression_node_subtype subtype_; node_type subtype_;
std::vector<lhs_rhs_element> & out_; std::vector<lhs_rhs_element> & out_;
}; };
@@ -96,9 +96,9 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
size_t root, size_t root,
bool inspect); bool inspect);
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> filter_elements(node_type subtype,
isaac::math_expression const & math_expression); isaac::math_expression const & math_expression);
const char * evaluate(operation_node_type type); const char * evaluate(operation_type type);
/** @brief functor for generating the expression string from a math_expression */ /** @brief functor for generating the expression string from a math_expression */
class evaluate_expression_traversal: public traversal_functor class evaluate_expression_traversal: public traversal_functor

View File

@@ -23,140 +23,121 @@ namespace isaac
class array_base; class array_base;
/** @brief Optimization enum for grouping operations into unary or binary operations. Just for optimization of lookups. */ /** @brief Optimization enum for grouping operations into unary or binary operations. Just for optimization of lookups. */
enum operation_node_type_family enum operation_type_family
{ {
OPERATOR_INVALID_TYPE_FAMILY = 0, INVALID_TYPE_FAMILY = 0,
// BLAS1-type // BLAS1-type
OPERATOR_UNARY_TYPE_FAMILY, UNARY_TYPE_FAMILY,
OPERATOR_BINARY_TYPE_FAMILY, BINARY_TYPE_FAMILY,
OPERATOR_VECTOR_DOT_TYPE_FAMILY, VECTOR_DOT_TYPE_FAMILY,
// BLAS2-type // BLAS2-type
OPERATOR_ROWS_DOT_TYPE_FAMILY, ROWS_DOT_TYPE_FAMILY,
OPERATOR_COLUMNS_DOT_TYPE_FAMILY, COLUMNS_DOT_TYPE_FAMILY,
// BLAS3-type // BLAS3-type
OPERATOR_GEMM_TYPE_FAMILY MATRIX_PRODUCT_TYPE_FAMILY
}; };
/** @brief Enumeration for identifying the possible operations */ /** @brief Enumeration for identifying the possible operations */
enum operation_node_type enum operation_type
{ {
OPERATOR_INVALID_TYPE = 0, INVALID_TYPE = 0,
// unary operator // unary operator
OPERATOR_MINUS_TYPE, MINUS_TYPE,
OPERATOR_NEGATE_TYPE, NEGATE_TYPE,
// unary expression // unary expression
OPERATOR_CAST_BOOL_TYPE, CAST_BOOL_TYPE,
OPERATOR_CAST_CHAR_TYPE, CAST_CHAR_TYPE,
OPERATOR_CAST_UCHAR_TYPE, CAST_UCHAR_TYPE,
OPERATOR_CAST_SHORT_TYPE, CAST_SHORT_TYPE,
OPERATOR_CAST_USHORT_TYPE, CAST_USHORT_TYPE,
OPERATOR_CAST_INT_TYPE, CAST_INT_TYPE,
OPERATOR_CAST_UINT_TYPE, CAST_UINT_TYPE,
OPERATOR_CAST_LONG_TYPE, CAST_LONG_TYPE,
OPERATOR_CAST_ULONG_TYPE, CAST_ULONG_TYPE,
OPERATOR_CAST_HALF_TYPE, CAST_HALF_TYPE,
OPERATOR_CAST_FLOAT_TYPE, CAST_FLOAT_TYPE,
OPERATOR_CAST_DOUBLE_TYPE, CAST_DOUBLE_TYPE,
OPERATOR_ABS_TYPE, ABS_TYPE,
OPERATOR_ACOS_TYPE, ACOS_TYPE,
OPERATOR_ASIN_TYPE, ASIN_TYPE,
OPERATOR_ATAN_TYPE, ATAN_TYPE,
OPERATOR_CEIL_TYPE, CEIL_TYPE,
OPERATOR_COS_TYPE, COS_TYPE,
OPERATOR_COSH_TYPE, COSH_TYPE,
OPERATOR_EXP_TYPE, EXP_TYPE,
OPERATOR_FABS_TYPE, FABS_TYPE,
OPERATOR_FLOOR_TYPE, FLOOR_TYPE,
OPERATOR_LOG_TYPE, LOG_TYPE,
OPERATOR_LOG10_TYPE, LOG10_TYPE,
OPERATOR_SIN_TYPE, SIN_TYPE,
OPERATOR_SINH_TYPE, SINH_TYPE,
OPERATOR_SQRT_TYPE, SQRT_TYPE,
OPERATOR_TAN_TYPE, TAN_TYPE,
OPERATOR_TANH_TYPE, TANH_TYPE,
OPERATOR_TRANS_TYPE, TRANS_TYPE,
// binary expression // binary expression
OPERATOR_ASSIGN_TYPE, ASSIGN_TYPE,
OPERATOR_INPLACE_ADD_TYPE, INPLACE_ADD_TYPE,
OPERATOR_INPLACE_SUB_TYPE, INPLACE_SUB_TYPE,
OPERATOR_ADD_TYPE, ADD_TYPE,
OPERATOR_SUB_TYPE, SUB_TYPE,
OPERATOR_MULT_TYPE, MULT_TYPE,
OPERATOR_DIV_TYPE, DIV_TYPE,
OPERATOR_ELEMENT_ARGFMAX_TYPE, ELEMENT_ARGFMAX_TYPE,
OPERATOR_ELEMENT_ARGFMIN_TYPE, ELEMENT_ARGFMIN_TYPE,
OPERATOR_ELEMENT_ARGMAX_TYPE, ELEMENT_ARGMAX_TYPE,
OPERATOR_ELEMENT_ARGMIN_TYPE, ELEMENT_ARGMIN_TYPE,
OPERATOR_ELEMENT_PROD_TYPE, ELEMENT_PROD_TYPE,
OPERATOR_ELEMENT_DIV_TYPE, ELEMENT_DIV_TYPE,
OPERATOR_ELEMENT_EQ_TYPE, ELEMENT_EQ_TYPE,
OPERATOR_ELEMENT_NEQ_TYPE, ELEMENT_NEQ_TYPE,
OPERATOR_ELEMENT_GREATER_TYPE, ELEMENT_GREATER_TYPE,
OPERATOR_ELEMENT_GEQ_TYPE, ELEMENT_GEQ_TYPE,
OPERATOR_ELEMENT_LESS_TYPE, ELEMENT_LESS_TYPE,
OPERATOR_ELEMENT_LEQ_TYPE, ELEMENT_LEQ_TYPE,
OPERATOR_ELEMENT_POW_TYPE, ELEMENT_POW_TYPE,
OPERATOR_ELEMENT_FMAX_TYPE, ELEMENT_FMAX_TYPE,
OPERATOR_ELEMENT_FMIN_TYPE, ELEMENT_FMIN_TYPE,
OPERATOR_ELEMENT_MAX_TYPE, ELEMENT_MAX_TYPE,
OPERATOR_ELEMENT_MIN_TYPE, ELEMENT_MIN_TYPE,
//Products //Products
OPERATOR_OUTER_PROD_TYPE, OUTER_PROD_TYPE,
OPERATOR_GEMM_NN_TYPE, MATRIX_PRODUCT_NN_TYPE,
OPERATOR_GEMM_TN_TYPE, MATRIX_PRODUCT_TN_TYPE,
OPERATOR_GEMM_NT_TYPE, MATRIX_PRODUCT_NT_TYPE,
OPERATOR_GEMM_TT_TYPE, MATRIX_PRODUCT_TT_TYPE,
//Access modifiers //Access modifiers
OPERATOR_MATRIX_DIAG_TYPE, MATRIX_DIAG_TYPE,
OPERATOR_MATRIX_ROW_TYPE, MATRIX_ROW_TYPE,
OPERATOR_MATRIX_COLUMN_TYPE, MATRIX_COLUMN_TYPE,
OPERATOR_REPEAT_TYPE, REPEAT_TYPE,
OPERATOR_RESHAPE_TYPE, RESHAPE_TYPE,
OPERATOR_SHIFT_TYPE, SHIFT_TYPE,
OPERATOR_VDIAG_TYPE, VDIAG_TYPE,
OPERATOR_ACCESS_INDEX_TYPE, ACCESS_INDEX_TYPE,
OPERATOR_PAIR_TYPE, PAIR_TYPE,
OPERATOR_FUSE, OPERATOR_FUSE,
OPERATOR_SFOR_TYPE, SFOR_TYPE,
};
/** @brief Groups the type of a node in the math_expression tree. Used for faster dispatching */
enum math_expression_node_type_family
{
INVALID_TYPE_FAMILY = 0,
COMPOSITE_OPERATOR_FAMILY,
VALUE_TYPE_FAMILY,
ARRAY_TYPE_FAMILY,
PLACEHOLDER_TYPE_FAMILY
};
/** @brief Encodes the type of a node in the math_expression tree. */
enum math_expression_node_subtype
{
INVALID_SUBTYPE = 0,
VALUE_SCALAR_TYPE,
DENSE_ARRAY_TYPE,
FOR_LOOP_INDEX_TYPE
}; };
struct op_element struct op_element
{ {
op_element(); op_element();
op_element(operation_node_type_family const & _type_family, operation_node_type const & _type); op_element(operation_type_family const & _type_family, operation_type const & _type);
operation_node_type_family type_family; operation_type_family type_family;
operation_node_type type; operation_type type;
}; };
struct for_idx_t struct for_idx_t
@@ -172,11 +153,19 @@ struct for_idx_t
int level; int level;
}; };
enum node_type
{
INVALID_SUBTYPE = 0,
COMPOSITE_OPERATOR_TYPE,
VALUE_SCALAR_TYPE,
DENSE_ARRAY_TYPE,
FOR_LOOP_INDEX_TYPE
};
struct lhs_rhs_element struct lhs_rhs_element
{ {
lhs_rhs_element(); lhs_rhs_element();
math_expression_node_type_family type_family; node_type subtype;
math_expression_node_subtype subtype;
numeric_type dtype; numeric_type dtype;
union union
{ {

View File

@@ -7,7 +7,7 @@
namespace isaac namespace isaac
{ {
std::string to_string(math_expression_node_subtype const & f); std::string to_string(node_type const & f);
std::string to_string(lhs_rhs_element const & e); std::string to_string(lhs_rhs_element const & e);
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node); std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
std::string to_string(isaac::math_expression const & s); std::string to_string(isaac::math_expression const & s);

View File

@@ -19,13 +19,13 @@ ISAACAPI typename std::conditional<std::is_arithmetic<T>::value, value_scalar, T
template<typename T, typename... Args> template<typename T, typename... Args>
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args) ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_PAIR_TYPE), context, numeric_type_of(x), {1}); } { return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx) inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
{ {
for(unsigned int i = 0 ; i < idx ; ++i){ for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root]; math_expression::node node = tree[root];
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index; root = node.rhs.node_index;
else else
return value_scalar(node.rhs.vscalar, node.rhs.dtype); return value_scalar(node.rhs.vscalar, node.rhs.dtype);

View File

@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
{ {
if(shape_.min()==0) return *this; if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype()); assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_); math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression)); execute(execution_handler(expression));
return *this; return *this;
} }
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
{ {
if(shape_.min()==0) return *this; if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype()); assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_); math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression)); execute(execution_handler(expression));
return *this; return *this;
} }
@@ -188,7 +188,7 @@ array_base& array_base::operator=(execution_handler const & c)
{ {
if(shape_.min()==0) return *this; if(shape_.min()==0) return *this;
assert(dtype_ == c.x().dtype()); assert(dtype_ == c.x().dtype());
math_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_); math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options())); execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
return *this; return *this;
} }
@@ -228,53 +228,53 @@ INSTANTIATE(double);
math_expression array_base::operator-() math_expression array_base::operator-()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); } { return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
math_expression array_base::operator!() math_expression array_base::operator!()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); } { return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
// //
array_base & array_base::operator+=(value_scalar const & rhs) array_base & array_base::operator+=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(array_base const & rhs) array_base & array_base::operator+=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(math_expression const & rhs) array_base & array_base::operator+=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
//---- //----
array_base & array_base::operator-=(value_scalar const & rhs) array_base & array_base::operator-=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(array_base const & rhs) array_base & array_base::operator-=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(math_expression const & rhs) array_base & array_base::operator-=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
//---- //----
array_base & array_base::operator*=(value_scalar const & rhs) array_base & array_base::operator*=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(array_base const & rhs) array_base & array_base::operator*=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(math_expression const & rhs) array_base & array_base::operator*=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
//---- //----
array_base & array_base::operator/=(value_scalar const & rhs) array_base & array_base::operator/=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(array_base const & rhs) array_base & array_base::operator/=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(math_expression const & rhs) array_base & array_base::operator/=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); } { return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
/*--- Indexing operators -----*/ /*--- Indexing operators -----*/
//--------------------------------------- //---------------------------------------
math_expression array_base::operator[](for_idx_t idx) const math_expression array_base::operator[](for_idx_t idx) const
{ {
return math_expression(*this, idx, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE), context_, dtype_, {1}); return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
} }
scalar array_base::operator [](int_t idx) scalar array_base::operator [](int_t idx)
@@ -512,72 +512,72 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \ #define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
math_expression OPNAME (array_base const & x, math_expression const & y) \ math_expression OPNAME (array_base const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\ \
math_expression OPNAME (array_base const & x, array_base const & y) \ math_expression OPNAME (array_base const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
\ \
math_expression OPNAME (array_base const & x, value_scalar const & y) \ math_expression OPNAME (array_base const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\ \
math_expression OPNAME (array_base const & x, for_idx_t const & y) \ math_expression OPNAME (array_base const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\ \
\ \
math_expression OPNAME (math_expression const & x, math_expression const & y) \ math_expression OPNAME (math_expression const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\ \
math_expression OPNAME (math_expression const & x, array_base const & y) \ math_expression OPNAME (math_expression const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\ \
math_expression OPNAME (math_expression const & x, value_scalar const & y) \ math_expression OPNAME (math_expression const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\ \
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \ math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\ \
\ \
math_expression OPNAME (value_scalar const & y, math_expression const & x) \ math_expression OPNAME (value_scalar const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \ { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\ \
math_expression OPNAME (value_scalar const & y, array_base const & x) \ math_expression OPNAME (value_scalar const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\ { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\ \
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \ math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); }\ { return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
\ \
\ \
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \ math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \ { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\ \
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \ math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \ { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
\ \
math_expression OPNAME (for_idx_t const & y, array_base const & x) \ math_expression OPNAME (for_idx_t const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\ { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\ \
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \ math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); } { return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(SUB_TYPE, operator -, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(MULT_TYPE, operator *, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(DIV_TYPE, operator /, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype()) DEFINE_ELEMENT_BINARY_OPERATOR(ASSIGN_TYPE, assign, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LESS_TYPE, operator <, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE) DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
#define DEFINE_OUTER(LTYPE, RTYPE) \ #define DEFINE_OUTER(LTYPE, RTYPE) \
math_expression outer(LTYPE const & x, RTYPE const & y)\ math_expression outer(LTYPE const & x, RTYPE const & y)\
@@ -585,7 +585,7 @@ math_expression outer(LTYPE const & x, RTYPE const & y)\
assert(x.dim()<=1 && y.dim()<=1);\ assert(x.dim()<=1 && y.dim()<=1);\
if(x.dim()<1 || y.dim()<1)\ if(x.dim()<1 || y.dim()<1)\
return x*y;\ return x*y;\
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
}\ }\
DEFINE_OUTER(array_base, array_base) DEFINE_OUTER(array_base, array_base)
@@ -622,61 +622,61 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
//--------------------------------------- //---------------------------------------
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \ #define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
math_expression OPNAME (array_base const & x) \ math_expression OPNAME (array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\ { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\ \
math_expression OPNAME (math_expression const & x) \ math_expression OPNAME (math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs) DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos) DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ASIN_TYPE, asin) DEFINE_ELEMENT_UNARY_OPERATOR(ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ATAN_TYPE, atan) DEFINE_ELEMENT_UNARY_OPERATOR(ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_CEIL_TYPE, ceil) DEFINE_ELEMENT_UNARY_OPERATOR(CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COS_TYPE, cos) DEFINE_ELEMENT_UNARY_OPERATOR(COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COSH_TYPE, cosh) DEFINE_ELEMENT_UNARY_OPERATOR(COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_EXP_TYPE, exp) DEFINE_ELEMENT_UNARY_OPERATOR(EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_FLOOR_TYPE, floor) DEFINE_ELEMENT_UNARY_OPERATOR(FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG_TYPE, log) DEFINE_ELEMENT_UNARY_OPERATOR(LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG10_TYPE,log10) DEFINE_ELEMENT_UNARY_OPERATOR(LOG10_TYPE,log10)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SIN_TYPE, sin) DEFINE_ELEMENT_UNARY_OPERATOR(SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SINH_TYPE, sinh) DEFINE_ELEMENT_UNARY_OPERATOR(SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SQRT_TYPE, sqrt) DEFINE_ELEMENT_UNARY_OPERATOR(SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TAN_TYPE, tan) DEFINE_ELEMENT_UNARY_OPERATOR(TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh) DEFINE_ELEMENT_UNARY_OPERATOR(TANH_TYPE, tanh)
#undef DEFINE_ELEMENT_UNARY_OPERATOR #undef DEFINE_ELEMENT_UNARY_OPERATOR
//--------------------------------------- //---------------------------------------
///*--- Misc----*/ ///*--- Misc----*/
////--------------------------------------- ////---------------------------------------
inline operation_node_type casted(numeric_type dtype) inline operation_type casted(numeric_type dtype)
{ {
switch(dtype) switch(dtype)
{ {
// case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE; // case BOOL_TYPE: return CAST_BOOL_TYPE;
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE; case CHAR_TYPE: return CAST_CHAR_TYPE;
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE; case UCHAR_TYPE: return CAST_UCHAR_TYPE;
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE; case SHORT_TYPE: return CAST_SHORT_TYPE;
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE; case USHORT_TYPE: return CAST_USHORT_TYPE;
case INT_TYPE: return OPERATOR_CAST_INT_TYPE; case INT_TYPE: return CAST_INT_TYPE;
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE; case UINT_TYPE: return CAST_UINT_TYPE;
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE; case LONG_TYPE: return CAST_LONG_TYPE;
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE; case ULONG_TYPE: return CAST_ULONG_TYPE;
// case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE; // case FLOAT_TYPE: return CAST_HALF_TYPE;
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE; case FLOAT_TYPE: return CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE; case DOUBLE_TYPE: return CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype); default: throw unknown_datatype(dtype);
} }
} }
math_expression cast(array_base const & x, numeric_type dtype) math_expression cast(array_base const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
math_expression cast(math_expression const & x, numeric_type dtype) math_expression cast(math_expression const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx) isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, {M, N}); } { return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
array diag(array_base & x, int offset) array diag(array_base & x, int offset)
{ {
@@ -689,7 +689,7 @@ array diag(array_base & x, int offset)
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx) isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, {M, N}); } { return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
inline shape_t flip(shape_t const & shape) inline shape_t flip(shape_t const & shape)
{ {
@@ -703,28 +703,28 @@ inline shape_t flip(shape_t const & shape)
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);} //{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
math_expression trans(array_base const & x) \ math_expression trans(array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\ { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\ \
math_expression trans(math_expression const & x) \ math_expression trans(math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2) math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
{ {
int_t sub1 = A.shape()[0]; int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1; int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2}); return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
} }
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2) math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
{ {
int_t sub1 = A.shape()[0]; int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1; int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2}); return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
} }
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \ #define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
math_expression row(TYPEA const & x, TYPEB const & i)\ math_expression row(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); } { return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
DEFINE_ACCESS_ROW(array_base, value_scalar) DEFINE_ACCESS_ROW(array_base, value_scalar)
DEFINE_ACCESS_ROW(array_base, for_idx_t) DEFINE_ACCESS_ROW(array_base, for_idx_t)
@@ -736,7 +736,7 @@ DEFINE_ACCESS_ROW(math_expression, math_expression)
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \ #define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
math_expression col(TYPEA const & x, TYPEB const & i)\ math_expression col(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); } { return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
DEFINE_ACCESS_COL(array_base, value_scalar) DEFINE_ACCESS_COL(array_base, value_scalar)
DEFINE_ACCESS_COL(array_base, for_idx_t) DEFINE_ACCESS_COL(array_base, for_idx_t)
@@ -756,11 +756,11 @@ math_expression OPNAME(array_base const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\ if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\ throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\ else if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\ return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\ else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\ return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\ else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\ return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}\ }\
\ \
math_expression OPNAME(math_expression const & x, int_t axis)\ math_expression OPNAME(math_expression const & x, int_t axis)\
@@ -768,18 +768,18 @@ math_expression OPNAME(math_expression const & x, int_t axis)\
if(axis < -1 || axis > x.dim())\ if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\ throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\ if(axis==-1)\
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\ return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\ else if(axis==0)\
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\ return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\ else\
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\ return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
} }
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum) DEFINE_REDUCTION(ADD_TYPE, sum)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax) DEFINE_REDUCTION(ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max) DEFINE_REDUCTION(ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min) DEFINE_REDUCTION(ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin) DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
#undef DEFINE_REDUCTION #undef DEFINE_REDUCTION
@@ -789,21 +789,21 @@ namespace detail
math_expression matmatprod(array_base const & A, array_base const & B) math_expression matmatprod(array_base const & A, array_base const & B)
{ {
shape_t shape{A.shape()[0], B.shape()[1]}; shape_t shape{A.shape()[0], B.shape()[1]};
return math_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape); return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
} }
math_expression matmatprod(math_expression const & A, array_base const & B) math_expression matmatprod(math_expression const & A, array_base const & B)
{ {
operation_node_type type = OPERATOR_GEMM_NN_TYPE; operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]}; shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]); math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==TRANS_TYPE;
if(A_trans){ if(A_trans){
type = OPERATOR_GEMM_TN_TYPE; type = MATRIX_PRODUCT_TN_TYPE;
} }
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape); math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]); math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
return res; return res;
@@ -811,16 +811,16 @@ namespace detail
math_expression matmatprod(array_base const & A, math_expression const & B) math_expression matmatprod(array_base const & A, math_expression const & B)
{ {
operation_node_type type = OPERATOR_GEMM_NN_TYPE; operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]}; shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]); math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE; bool B_trans = B_root.op.type==TRANS_TYPE;
if(B_trans){ if(B_trans){
type = OPERATOR_GEMM_NT_TYPE; type = MATRIX_PRODUCT_NT_TYPE;
} }
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape); math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]); math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;
return res; return res;
@@ -828,20 +828,20 @@ namespace detail
math_expression matmatprod(math_expression const & A, math_expression const & B) math_expression matmatprod(math_expression const & A, math_expression const & B)
{ {
operation_node_type type = OPERATOR_GEMM_NN_TYPE; operation_type type = MATRIX_PRODUCT_NN_TYPE;
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]); math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]); math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
shape_t shape{A.shape()[0], B.shape()[1]}; shape_t shape{A.shape()[0], B.shape()[1]};
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE; bool B_trans = B_root.op.type==TRANS_TYPE;
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE; if(A_trans && B_trans) type = MATRIX_PRODUCT_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE; else if(A_trans && !B_trans) type = MATRIX_PRODUCT_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE; else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
else type = OPERATOR_GEMM_NN_TYPE; else type = MATRIX_PRODUCT_NN_TYPE;
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape); math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]); math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;
@@ -862,14 +862,14 @@ namespace detail
int_t M = A.shape()[0]; int_t M = A.shape()[0];
int_t N = A.shape()[1]; int_t N = A.shape()[1];
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]); math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==TRANS_TYPE;
while(A_root.lhs.type_family==COMPOSITE_OPERATOR_FAMILY){ while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
A_root = A.tree()[A_root.lhs.node_index]; A_root = A.tree()[A_root.lhs.node_index];
A_trans ^= A_root.op.type==OPERATOR_TRANS_TYPE; A_trans ^= A_root.op.type==TRANS_TYPE;
} }
if(A_trans) if(A_trans)
{ {
math_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M}); math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
//Remove trans //Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs; tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 0); return sum(tmp, 0);
@@ -890,10 +890,10 @@ ISAACAPI void swap(view x, view y)
//Reshape //Reshape
math_expression reshape(array_base const & x, shape_t const & shape) math_expression reshape(array_base const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression reshape(math_expression const & x, shape_t const & shape) math_expression reshape(math_expression const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); } { return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression ravel(array_base const & x) math_expression ravel(array_base const & x)
{ return reshape(x, {x.shape().prod()}); } { return reshape(x, {x.shape().prod()}); }
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
if(x.dim()==2 && x.shape()[1]==0)\ if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\ return zeros(x.shape()[0], y.shape()[1], dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\ if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
return math_expression(invalid_node(), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_INVALID_TYPE), context, dtype, {0});\ return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
if(x.dim()==1 && y.dim()==1)\ if(x.dim()==1 && y.dim()==1)\
return sum(x*y);\ return sum(x*y);\
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\ if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
@@ -967,13 +967,13 @@ DEFINE_NORM(math_expression)
math_expression fuse(math_expression const & x, math_expression const & y) math_expression fuse(math_expression const & x, math_expression const & y)
{ {
assert(x.context()==y.context()); assert(x.context()==y.context());
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape()); return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
} }
/*--- For loops ---*/ /*--- For loops ---*/
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x) ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
{ {
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE), x.context(), x.dtype(), x.shape()); return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
} }

View File

@@ -96,7 +96,7 @@ mapped_object& get(math_expression::container_type const & tree, size_t root, ma
{ {
for(unsigned int i = 0 ; i < idx ; ++i){ for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root]; math_expression::node node = tree[root];
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index; root = node.rhs.node_index;
else else
return *(mapping.at(std::make_pair(root, RHS_NODE_TYPE))); return *(mapping.at(std::make_pair(root, RHS_NODE_TYPE)));
@@ -136,10 +136,10 @@ math_expression::node mapped_reduce::root_node() const
bool mapped_reduce::is_index_reduction() const bool mapped_reduce::is_index_reduction() const
{ {
op_element const & op = root_op(); op_element const & op = root_op();
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE return op.type==ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE || op.type==ELEMENT_ARGMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE || op.type==ELEMENT_ARGFMIN_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE; || op.type==ELEMENT_ARGMIN_TYPE;
} }
op_element mapped_reduce::root_op() const op_element mapped_reduce::root_op() const
@@ -358,27 +358,27 @@ void mapped_outer::postprocess(std::string &res) const
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info) mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
{ } { }
std::string mapped_cast::operator_to_str(operation_node_type type) std::string mapped_cast::operator_to_str(operation_type type)
{ {
switch(type) switch(type)
{ {
case OPERATOR_CAST_BOOL_TYPE : return "bool"; case CAST_BOOL_TYPE : return "bool";
case OPERATOR_CAST_CHAR_TYPE : return "char"; case CAST_CHAR_TYPE : return "char";
case OPERATOR_CAST_UCHAR_TYPE : return "uchar"; case CAST_UCHAR_TYPE : return "uchar";
case OPERATOR_CAST_SHORT_TYPE : return "short"; case CAST_SHORT_TYPE : return "short";
case OPERATOR_CAST_USHORT_TYPE : return "ushort"; case CAST_USHORT_TYPE : return "ushort";
case OPERATOR_CAST_INT_TYPE : return "int"; case CAST_INT_TYPE : return "int";
case OPERATOR_CAST_UINT_TYPE : return "uint"; case CAST_UINT_TYPE : return "uint";
case OPERATOR_CAST_LONG_TYPE : return "long"; case CAST_LONG_TYPE : return "long";
case OPERATOR_CAST_ULONG_TYPE : return "ulong"; case CAST_ULONG_TYPE : return "ulong";
case OPERATOR_CAST_HALF_TYPE : return "half"; case CAST_HALF_TYPE : return "half";
case OPERATOR_CAST_FLOAT_TYPE : return "float"; case CAST_FLOAT_TYPE : return "float";
case OPERATOR_CAST_DOUBLE_TYPE : return "double"; case CAST_DOUBLE_TYPE : return "double";
default : return "invalid"; default : return "invalid";
} }
} }
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast") mapped_cast::mapped_cast(operation_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
{ } { }

View File

@@ -16,103 +16,103 @@ namespace detail
bool is_scalar_reduce_1d(math_expression::node const & node) bool is_scalar_reduce_1d(math_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY; return node.op.type_family==VECTOR_DOT_TYPE_FAMILY;
} }
bool is_vector_reduce_1d(math_expression::node const & node) bool is_vector_reduce_1d(math_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY return node.op.type_family==ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY; || node.op.type_family==COLUMNS_DOT_TYPE_FAMILY;
} }
bool is_assignment(op_element const & op) bool is_assignment(op_element const & op)
{ {
return op.type== OPERATOR_ASSIGN_TYPE return op.type== ASSIGN_TYPE
|| op.type== OPERATOR_INPLACE_ADD_TYPE || op.type== INPLACE_ADD_TYPE
|| op.type== OPERATOR_INPLACE_SUB_TYPE; || op.type== INPLACE_SUB_TYPE;
} }
bool is_elementwise_operator(op_element const & op) bool is_elementwise_operator(op_element const & op)
{ {
return is_assignment(op) return is_assignment(op)
|| op.type== OPERATOR_ADD_TYPE || op.type== ADD_TYPE
|| op.type== OPERATOR_SUB_TYPE || op.type== SUB_TYPE
|| op.type== OPERATOR_ELEMENT_PROD_TYPE || op.type== ELEMENT_PROD_TYPE
|| op.type== OPERATOR_ELEMENT_DIV_TYPE || op.type== ELEMENT_DIV_TYPE
|| op.type== OPERATOR_MULT_TYPE || op.type== MULT_TYPE
|| op.type== OPERATOR_DIV_TYPE || op.type== DIV_TYPE
|| op.type== OPERATOR_ELEMENT_EQ_TYPE || op.type== ELEMENT_EQ_TYPE
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE || op.type== ELEMENT_NEQ_TYPE
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE || op.type== ELEMENT_GREATER_TYPE
|| op.type== OPERATOR_ELEMENT_LESS_TYPE || op.type== ELEMENT_LESS_TYPE
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE || op.type== ELEMENT_GEQ_TYPE
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ; || op.type== ELEMENT_LEQ_TYPE ;
} }
bool bypass(op_element const & op) bool bypass(op_element const & op)
{ {
return op.type == OPERATOR_RESHAPE_TYPE return op.type == RESHAPE_TYPE
||op.type == OPERATOR_TRANS_TYPE; ||op.type == TRANS_TYPE;
} }
bool is_cast(op_element const & op) bool is_cast(op_element const & op)
{ {
return op.type== OPERATOR_CAST_BOOL_TYPE return op.type== CAST_BOOL_TYPE
|| op.type== OPERATOR_CAST_CHAR_TYPE || op.type== CAST_CHAR_TYPE
|| op.type== OPERATOR_CAST_UCHAR_TYPE || op.type== CAST_UCHAR_TYPE
|| op.type== OPERATOR_CAST_SHORT_TYPE || op.type== CAST_SHORT_TYPE
|| op.type== OPERATOR_CAST_USHORT_TYPE || op.type== CAST_USHORT_TYPE
|| op.type== OPERATOR_CAST_INT_TYPE || op.type== CAST_INT_TYPE
|| op.type== OPERATOR_CAST_UINT_TYPE || op.type== CAST_UINT_TYPE
|| op.type== OPERATOR_CAST_LONG_TYPE || op.type== CAST_LONG_TYPE
|| op.type== OPERATOR_CAST_ULONG_TYPE || op.type== CAST_ULONG_TYPE
|| op.type== OPERATOR_CAST_FLOAT_TYPE || op.type== CAST_FLOAT_TYPE
|| op.type== OPERATOR_CAST_DOUBLE_TYPE || op.type== CAST_DOUBLE_TYPE
; ;
} }
bool is_node_leaf(op_element const & op) bool is_node_leaf(op_element const & op)
{ {
return op.type==OPERATOR_MATRIX_DIAG_TYPE return op.type==MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VDIAG_TYPE || op.type==VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE || op.type==REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE || op.type==MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE || op.type==MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_ACCESS_INDEX_TYPE || op.type==ACCESS_INDEX_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE || op.type==OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY || op.type_family==VECTOR_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY || op.type_family==ROWS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY || op.type_family==COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY || op.type_family==MATRIX_PRODUCT_TYPE_FAMILY
; ;
} }
bool is_elementwise_function(op_element const & op) bool is_elementwise_function(op_element const & op)
{ {
return is_cast(op) return is_cast(op)
|| op.type== OPERATOR_ABS_TYPE || op.type== ABS_TYPE
|| op.type== OPERATOR_ACOS_TYPE || op.type== ACOS_TYPE
|| op.type== OPERATOR_ASIN_TYPE || op.type== ASIN_TYPE
|| op.type== OPERATOR_ATAN_TYPE || op.type== ATAN_TYPE
|| op.type== OPERATOR_CEIL_TYPE || op.type== CEIL_TYPE
|| op.type== OPERATOR_COS_TYPE || op.type== COS_TYPE
|| op.type== OPERATOR_COSH_TYPE || op.type== COSH_TYPE
|| op.type== OPERATOR_EXP_TYPE || op.type== EXP_TYPE
|| op.type== OPERATOR_FABS_TYPE || op.type== FABS_TYPE
|| op.type== OPERATOR_FLOOR_TYPE || op.type== FLOOR_TYPE
|| op.type== OPERATOR_LOG_TYPE || op.type== LOG_TYPE
|| op.type== OPERATOR_LOG10_TYPE || op.type== LOG10_TYPE
|| op.type== OPERATOR_SIN_TYPE || op.type== SIN_TYPE
|| op.type== OPERATOR_SINH_TYPE || op.type== SINH_TYPE
|| op.type== OPERATOR_SQRT_TYPE || op.type== SQRT_TYPE
|| op.type== OPERATOR_TAN_TYPE || op.type== TAN_TYPE
|| op.type== OPERATOR_TANH_TYPE || op.type== TANH_TYPE
|| op.type== OPERATOR_ELEMENT_POW_TYPE || op.type== ELEMENT_POW_TYPE
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE || op.type== ELEMENT_FMAX_TYPE
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE || op.type== ELEMENT_FMIN_TYPE
|| op.type== OPERATOR_ELEMENT_MAX_TYPE || op.type== ELEMENT_MAX_TYPE
|| op.type== OPERATOR_ELEMENT_MIN_TYPE; || op.type== ELEMENT_MIN_TYPE;
} }
@@ -139,7 +139,7 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
} }
// //
filter_elements_fun::filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out) : filter_elements_fun::filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out) :
subtype_(subtype), out_(out) subtype_(subtype), out_(out)
{ } { }
@@ -153,84 +153,84 @@ void filter_elements_fun::operator()(isaac::math_expression const & math_express
} }
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype, isaac::math_expression const & math_expression) std::vector<lhs_rhs_element> filter_elements(node_type subtype, isaac::math_expression const & math_expression)
{ {
std::vector<lhs_rhs_element> res; std::vector<lhs_rhs_element> res;
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true); traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
return res; return res;
} }
/** @brief generate a string from an operation_node_type */ /** @brief generate a string from an operation_type */
const char * evaluate(operation_node_type type) const char * evaluate(operation_type type)
{ {
// unary expression // unary expression
switch (type) switch (type)
{ {
//Function //Function
case OPERATOR_ABS_TYPE : return "abs"; case ABS_TYPE : return "abs";
case OPERATOR_ACOS_TYPE : return "acos"; case ACOS_TYPE : return "acos";
case OPERATOR_ASIN_TYPE : return "asin"; case ASIN_TYPE : return "asin";
case OPERATOR_ATAN_TYPE : return "atan"; case ATAN_TYPE : return "atan";
case OPERATOR_CEIL_TYPE : return "ceil"; case CEIL_TYPE : return "ceil";
case OPERATOR_COS_TYPE : return "cos"; case COS_TYPE : return "cos";
case OPERATOR_COSH_TYPE : return "cosh"; case COSH_TYPE : return "cosh";
case OPERATOR_EXP_TYPE : return "exp"; case EXP_TYPE : return "exp";
case OPERATOR_FABS_TYPE : return "fabs"; case FABS_TYPE : return "fabs";
case OPERATOR_FLOOR_TYPE : return "floor"; case FLOOR_TYPE : return "floor";
case OPERATOR_LOG_TYPE : return "log"; case LOG_TYPE : return "log";
case OPERATOR_LOG10_TYPE : return "log10"; case LOG10_TYPE : return "log10";
case OPERATOR_SIN_TYPE : return "sin"; case SIN_TYPE : return "sin";
case OPERATOR_SINH_TYPE : return "sinh"; case SINH_TYPE : return "sinh";
case OPERATOR_SQRT_TYPE : return "sqrt"; case SQRT_TYPE : return "sqrt";
case OPERATOR_TAN_TYPE : return "tan"; case TAN_TYPE : return "tan";
case OPERATOR_TANH_TYPE : return "tanh"; case TANH_TYPE : return "tanh";
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax"; case ELEMENT_ARGFMAX_TYPE : return "argfmax";
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax"; case ELEMENT_ARGMAX_TYPE : return "argmax";
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin"; case ELEMENT_ARGFMIN_TYPE : return "argfmin";
case OPERATOR_ELEMENT_ARGMIN_TYPE : return "argmin"; case ELEMENT_ARGMIN_TYPE : return "argmin";
case OPERATOR_ELEMENT_POW_TYPE : return "pow"; case ELEMENT_POW_TYPE : return "pow";
//Arithmetic //Arithmetic
case OPERATOR_MINUS_TYPE : return "-"; case MINUS_TYPE : return "-";
case OPERATOR_ASSIGN_TYPE : return "="; case ASSIGN_TYPE : return "=";
case OPERATOR_INPLACE_ADD_TYPE : return "+="; case INPLACE_ADD_TYPE : return "+=";
case OPERATOR_INPLACE_SUB_TYPE : return "-="; case INPLACE_SUB_TYPE : return "-=";
case OPERATOR_ADD_TYPE : return "+"; case ADD_TYPE : return "+";
case OPERATOR_SUB_TYPE : return "-"; case SUB_TYPE : return "-";
case OPERATOR_MULT_TYPE : return "*"; case MULT_TYPE : return "*";
case OPERATOR_ELEMENT_PROD_TYPE : return "*"; case ELEMENT_PROD_TYPE : return "*";
case OPERATOR_DIV_TYPE : return "/"; case DIV_TYPE : return "/";
case OPERATOR_ELEMENT_DIV_TYPE : return "/"; case ELEMENT_DIV_TYPE : return "/";
//Relational //Relational
case OPERATOR_NEGATE_TYPE: return "!"; case NEGATE_TYPE: return "!";
case OPERATOR_ELEMENT_EQ_TYPE : return "=="; case ELEMENT_EQ_TYPE : return "==";
case OPERATOR_ELEMENT_NEQ_TYPE : return "!="; case ELEMENT_NEQ_TYPE : return "!=";
case OPERATOR_ELEMENT_GREATER_TYPE : return ">"; case ELEMENT_GREATER_TYPE : return ">";
case OPERATOR_ELEMENT_GEQ_TYPE : return ">="; case ELEMENT_GEQ_TYPE : return ">=";
case OPERATOR_ELEMENT_LESS_TYPE : return "<"; case ELEMENT_LESS_TYPE : return "<";
case OPERATOR_ELEMENT_LEQ_TYPE : return "<="; case ELEMENT_LEQ_TYPE : return "<=";
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax"; case ELEMENT_FMAX_TYPE : return "fmax";
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin"; case ELEMENT_FMIN_TYPE : return "fmin";
case OPERATOR_ELEMENT_MAX_TYPE : return "max"; case ELEMENT_MAX_TYPE : return "max";
case OPERATOR_ELEMENT_MIN_TYPE : return "min"; case ELEMENT_MIN_TYPE : return "min";
//Binary //Binary
case OPERATOR_GEMM_NN_TYPE : return "prodNN"; case MATRIX_PRODUCT_NN_TYPE : return "prodNN";
case OPERATOR_GEMM_TN_TYPE : return "prodTN"; case MATRIX_PRODUCT_TN_TYPE : return "prodTN";
case OPERATOR_GEMM_NT_TYPE : return "prodNT"; case MATRIX_PRODUCT_NT_TYPE : return "prodNT";
case OPERATOR_GEMM_TT_TYPE : return "prodTT"; case MATRIX_PRODUCT_TT_TYPE : return "prodTT";
case OPERATOR_VDIAG_TYPE : return "vdiag"; case VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag"; case MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row"; case MATRIX_ROW_TYPE : return "row";
case OPERATOR_MATRIX_COLUMN_TYPE : return "col"; case MATRIX_COLUMN_TYPE : return "col";
case OPERATOR_PAIR_TYPE: return "pair"; case PAIR_TYPE: return "pair";
case OPERATOR_ACCESS_INDEX_TYPE: return "access"; case ACCESS_INDEX_TYPE: return "access";
//FOR //FOR
case OPERATOR_SFOR_TYPE: return "sfor"; case SFOR_TYPE: return "sfor";
default : throw operation_not_supported_exception("Unsupported operator"); default : throw operation_not_supported_exception("Unsupported operator");
} }
@@ -245,7 +245,7 @@ void evaluate_expression_traversal::call_before_expansion(isaac::math_expression
math_expression::node const & root_node = math_expression.tree()[root_idx]; math_expression::node const & root_node = math_expression.tree()[root_idx];
if(detail::is_cast(root_node.op)) if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_); str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if (( (root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY&&root_node.op.type!=OPERATOR_ADD_TYPE) || detail::is_elementwise_function(root_node.op)) else if (( (root_node.op.type_family==UNARY_TYPE_FAMILY&&root_node.op.type!=ADD_TYPE) || detail::is_elementwise_function(root_node.op))
&& !detail::is_node_leaf(root_node.op)) && !detail::is_node_leaf(root_node.op))
str_+=evaluate(root_node.op.type); str_+=evaluate(root_node.op.type);
if(root_node.op.type!=OPERATOR_FUSE) if(root_node.op.type!=OPERATOR_FUSE)
@@ -268,7 +268,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
{ {
if (detail::is_node_leaf(root_node.op)) if (detail::is_node_leaf(root_node.op))
str_ += mapping_.at(key)->evaluate(accessors_); str_ += mapping_.at(key)->evaluate(accessors_);
else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY) else if(root_node.op.type_family!=UNARY_TYPE_FAMILY)
{ {
if (detail::is_elementwise_operator(root_node.op)) if (detail::is_elementwise_operator(root_node.op))
str_ += evaluate(root_node.op.type); str_ += evaluate(root_node.op.type);
@@ -280,7 +280,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
{ {
if (leaf==LHS_NODE_TYPE) if (leaf==LHS_NODE_TYPE)
{ {
if (root_node.lhs.type_family!=COMPOSITE_OPERATOR_FAMILY) if (root_node.lhs.subtype!=COMPOSITE_OPERATOR_TYPE)
{ {
if (root_node.lhs.subtype==FOR_LOOP_INDEX_TYPE) if (root_node.lhs.subtype==FOR_LOOP_INDEX_TYPE)
str_ += "sforidx" + tools::to_string(root_node.lhs.for_idx.level); str_ += "sforidx" + tools::to_string(root_node.lhs.for_idx.level);
@@ -291,7 +291,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
if (leaf==RHS_NODE_TYPE) if (leaf==RHS_NODE_TYPE)
{ {
if (root_node.rhs.type_family!=COMPOSITE_OPERATOR_FAMILY) if (root_node.rhs.subtype!=COMPOSITE_OPERATOR_TYPE)
{ {
if (root_node.rhs.subtype==FOR_LOOP_INDEX_TYPE) if (root_node.rhs.subtype==FOR_LOOP_INDEX_TYPE)
str_ += "sforidx" + tools::to_string(root_node.rhs.for_idx.level); str_ += "sforidx" + tools::to_string(root_node.rhs.for_idx.level);
@@ -312,14 +312,14 @@ std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & acc
if (leaf==RHS_NODE_TYPE) if (leaf==RHS_NODE_TYPE)
{ {
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false); traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
else else
traversal_functor(math_expression, root_idx, leaf); traversal_functor(math_expression, root_idx, leaf);
} }
else if (leaf==LHS_NODE_TYPE) else if (leaf==LHS_NODE_TYPE)
{ {
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false); traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
else else
traversal_functor(math_expression, root_idx, leaf); traversal_functor(math_expression, root_idx, leaf);
@@ -369,14 +369,14 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::strin
if (leaf==RHS_NODE_TYPE) if (leaf==RHS_NODE_TYPE)
{ {
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true); traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
else else
traversal_functor(math_expression, root_idx, leaf); traversal_functor(math_expression, root_idx, leaf);
} }
else if (leaf==LHS_NODE_TYPE) else if (leaf==LHS_NODE_TYPE)
{ {
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY) if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true); traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
else else
traversal_functor(math_expression, root_idx, leaf); traversal_functor(math_expression, root_idx, leaf);
@@ -440,9 +440,9 @@ void math_expression_representation_functor::append(char*& p, const char * str)
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
{ {
math_expression::node const & root_node = math_expression.tree()[root_idx]; math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY) if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
append(root_node.lhs, detail::is_assignment(root_node.op)); append(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY) else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
append(root_node.rhs, false); append(root_node.rhs, false);
else if (leaf_t==PARENT_NODE_TYPE) else if (leaf_t==PARENT_NODE_TYPE)
append_id(ptr_,root_node.op.type); append_id(ptr_,root_node.op.type);

View File

@@ -39,11 +39,11 @@ bool base::requires_fallback(math_expression const & expression)
int_t base::vector_size(math_expression::node const & node) int_t base::vector_size(math_expression::node const & node)
{ {
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE) if (node.op.type==MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]); return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE) else if (node.op.type==MATRIX_ROW_TYPE)
return node.lhs.array->shape()[1]; return node.lhs.array->shape()[1];
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE) else if (node.op.type==MATRIX_COLUMN_TYPE)
return node.lhs.array->shape()[0]; return node.lhs.array->shape()[0];
else else
return node.lhs.array->shape().max(); return node.lhs.array->shape().max();
@@ -52,12 +52,12 @@ int_t base::vector_size(math_expression::node const & node)
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node) std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
{ {
if (node.op.type==OPERATOR_VDIAG_TYPE) if (node.op.type==VDIAG_TYPE)
{ {
int_t size = node.lhs.array->shape()[0]; int_t size = node.lhs.array->shape()[0];
return std::make_pair(size,size); return std::make_pair(size,size);
} }
else if(node.op.type==OPERATOR_REPEAT_TYPE) else if(node.op.type==REPEAT_TYPE)
{ {
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0); size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1); size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);

View File

@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
math_expression::container_type const & tree = expressions.tree(); math_expression::container_type const & tree = expressions.tree();
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==OPERATOR_SFOR_TYPE;}, expressions, expressions.root(), true); std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
for(unsigned int i = 0 ; i < sfors.size() ; ++i) for(unsigned int i = 0 ; i < sfors.size() ; ++i)
{ {

View File

@@ -81,11 +81,11 @@ public:
void set_arguments(lhs_rhs_element const & lhs_rhs, bool is_assigned) const void set_arguments(lhs_rhs_element const & lhs_rhs, bool is_assigned) const
{ {
switch(lhs_rhs.type_family) switch(lhs_rhs.subtype)
{ {
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar); case VALUE_SCALAR_TYPE: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array, is_assigned); case DENSE_ARRAY_TYPE: return set_arguments(lhs_rhs.array, is_assigned);
case PLACEHOLDER_TYPE_FAMILY: return; case FOR_LOOP_INDEX_TYPE: return;
default: throw std::runtime_error("Unrecognized type family"); default: throw std::runtime_error("Unrecognized type family");
} }
} }
@@ -93,9 +93,9 @@ public:
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
{ {
math_expression::node const & root_node = math_expression.tree()[root_idx]; math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY) if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.lhs, detail::is_assignment(root_node.op)); set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY) else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.rhs, false); set_arguments(root_node.rhs, false);
} }

View File

@@ -44,11 +44,11 @@ class map_functor : public traversal_functor
std::shared_ptr<mapped_object> create(lhs_rhs_element const & lhs_rhs, bool is_assigned = false) const std::shared_ptr<mapped_object> create(lhs_rhs_element const & lhs_rhs, bool is_assigned = false) const
{ {
switch(lhs_rhs.type_family) switch(lhs_rhs.subtype)
{ {
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar); case VALUE_SCALAR_TYPE: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array, is_assigned); case DENSE_ARRAY_TYPE: return create(lhs_rhs.array, is_assigned);
case PLACEHOLDER_TYPE_FAMILY: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level)); case FOR_LOOP_INDEX_TYPE: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level));
default: throw ""; default: throw "";
} }
} }
@@ -65,31 +65,31 @@ public:
mapping_type::key_type key(root_idx, leaf_t); mapping_type::key_type key(root_idx, leaf_t);
math_expression::node const & root_node = math_expression.tree()[root_idx]; math_expression::node const & root_node = math_expression.tree()[root_idx];
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY) if (leaf_t == LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op)))); mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY) else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
mapping_.insert(mapping_type::value_type(key, create(root_node.rhs))); mapping_.insert(mapping_type::value_type(key, create(root_node.rhs)));
else if ( leaf_t== PARENT_NODE_TYPE) else if ( leaf_t== PARENT_NODE_TYPE)
{ {
if (root_node.op.type==OPERATOR_VDIAG_TYPE) if (root_node.op.type==VDIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_DIAG_TYPE) else if (root_node.op.type==MATRIX_DIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_ROW_TYPE) else if (root_node.op.type==MATRIX_ROW_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE) else if (root_node.op.type==MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
else if(root_node.op.type==OPERATOR_ACCESS_INDEX_TYPE) else if(root_node.op.type==ACCESS_INDEX_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
else if (detail::is_scalar_reduce_1d(root_node)) else if (detail::is_scalar_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
else if (detail::is_vector_reduce_1d(root_node)) else if (detail::is_vector_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY) else if (root_node.op.type_family == MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE) else if (root_node.op.type == REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE) else if (root_node.op.type == OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op)) else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get())))); mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));

View File

@@ -25,10 +25,10 @@ inline void compute_index_reduce_1d(kernel_generation_stream & os, std::string a
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl; // os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl; os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
os << acc_value << "="; os << acc_value << "=";
if (op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE) os << "fmax"; if (op.type==ELEMENT_ARGFMAX_TYPE) os << "fmax";
if (op.type==OPERATOR_ELEMENT_ARGMAX_TYPE) os << "max"; if (op.type==ELEMENT_ARGMAX_TYPE) os << "max";
if (op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE) os << "fmin"; if (op.type==ELEMENT_ARGFMIN_TYPE) os << "fmin";
if (op.type==OPERATOR_ELEMENT_ARGMIN_TYPE) os << "min"; if (op.type==ELEMENT_ARGMIN_TYPE) os << "min";
os << "(" << acc_value << "," << cur_value << ");"<< std::endl; os << "(" << acc_value << "," << cur_value << ");"<< std::endl;
} }
@@ -39,17 +39,17 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
switch (op.type) switch (op.type)
{ {
case OPERATOR_ADD_TYPE : return "0"; case ADD_TYPE : return "0";
case OPERATOR_MULT_TYPE : return "1"; case MULT_TYPE : return "1";
case OPERATOR_DIV_TYPE : return "1"; case DIV_TYPE : return "1";
case OPERATOR_ELEMENT_FMAX_TYPE : return N_INF; case ELEMENT_FMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return N_INF; case ELEMENT_ARGFMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_MAX_TYPE : return N_INF; case ELEMENT_MAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_ARGMAX_TYPE : return N_INF; case ELEMENT_ARGMAX_TYPE : return N_INF;
case OPERATOR_ELEMENT_FMIN_TYPE : return INF; case ELEMENT_FMIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return INF; case ELEMENT_ARGFMIN_TYPE : return INF;
case OPERATOR_ELEMENT_MIN_TYPE : return INF; case ELEMENT_MIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF; case ELEMENT_ARGMIN_TYPE : return INF;
default: throw std::runtime_error("Unsupported reduce_1d operator : no neutral element known"); default: throw std::runtime_error("Unsupported reduce_1d operator : no neutral element known");
} }
@@ -57,18 +57,18 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
inline bool is_reduce_1d(math_expression::node const & node) inline bool is_reduce_1d(math_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY return node.op.type_family==VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY || node.op.type_family==COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY; || node.op.type_family==ROWS_DOT_TYPE_FAMILY;
} }
inline bool is_index_reduction(op_element const & op) inline bool is_index_reduction(op_element const & op)
{ {
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE return op.type==ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE || op.type==ELEMENT_ARGMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE || op.type==ELEMENT_ARGFMIN_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE; || op.type==ELEMENT_ARGMIN_TYPE;
} }

View File

@@ -32,27 +32,27 @@ namespace isaac
bool result = false; bool result = false;
switch(op.type_family) switch(op.type_family)
{ {
case OPERATOR_UNARY_TYPE_FAMILY: case UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY: case BINARY_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS) || (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|| (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS); || (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
break; break;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY: case VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression) result |= is_mvprod(expression)
|| expression==REDUCE_1D; || expression==REDUCE_1D;
break; break;
case OPERATOR_ROWS_DOT_TYPE_FAMILY: case ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCE_1D; || expression==REDUCE_1D;
break; break;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY: case COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCE_1D; || expression==REDUCE_1D;
break; break;
case OPERATOR_GEMM_TYPE_FAMILY: case MATRIX_PRODUCT_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first) result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCE_1D; || expression==REDUCE_1D;
@@ -74,30 +74,30 @@ namespace isaac
{ {
switch(op.type_family) switch(op.type_family)
{ {
case OPERATOR_UNARY_TYPE_FAMILY: case UNARY_TYPE_FAMILY:
if(is_mmprod(left)) if(is_mmprod(left))
return ELEMENTWISE_2D; return ELEMENTWISE_2D;
return left; return left;
case OPERATOR_BINARY_TYPE_FAMILY: case BINARY_TYPE_FAMILY:
if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS; if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS; else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D; else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D; else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OPERATOR_OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D; else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D; else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
else if(right == INVALID_EXPRESSION_TYPE) return left; else if(right == INVALID_EXPRESSION_TYPE) return left;
else if(left == INVALID_EXPRESSION_TYPE) return right; else if(left == INVALID_EXPRESSION_TYPE) return right;
throw; throw;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY: case VECTOR_DOT_TYPE_FAMILY:
return REDUCE_1D; return REDUCE_1D;
case OPERATOR_ROWS_DOT_TYPE_FAMILY: case ROWS_DOT_TYPE_FAMILY:
return REDUCE_2D_ROWS; return REDUCE_2D_ROWS;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY: case COLUMNS_DOT_TYPE_FAMILY:
return REDUCE_2D_COLS; return REDUCE_2D_COLS;
case OPERATOR_GEMM_TYPE_FAMILY: case MATRIX_PRODUCT_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return MATRIX_PRODUCT_NN; if(op.type==MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return MATRIX_PRODUCT_TN; else if(op.type==MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return MATRIX_PRODUCT_NT; else if(op.type==MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT;
else return MATRIX_PRODUCT_TT; else return MATRIX_PRODUCT_TT;
default: default:
throw; throw;
@@ -115,11 +115,11 @@ namespace isaac
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;}; auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
//Left //Left
expression_type type_left = INVALID_EXPRESSION_TYPE; expression_type type_left = INVALID_EXPRESSION_TYPE;
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY) if (node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
parse(array, node.lhs.node_index, breakpoints, type_left, false); parse(array, node.lhs.node_index, breakpoints, type_left, false);
else if(node.lhs.subtype == DENSE_ARRAY_TYPE) else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1) if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
type_left = ELEMENTWISE_1D; type_left = ELEMENTWISE_1D;
else else
type_left = ELEMENTWISE_2D; type_left = ELEMENTWISE_2D;
@@ -127,11 +127,11 @@ namespace isaac
//Right //Right
expression_type type_right = INVALID_EXPRESSION_TYPE; expression_type type_right = INVALID_EXPRESSION_TYPE;
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY) if (node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
parse(array, node.rhs.node_index, breakpoints, type_right, false); parse(array, node.rhs.node_index, breakpoints, type_right, false);
else if(node.rhs.subtype == DENSE_ARRAY_TYPE) else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1) if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
type_right = ELEMENTWISE_1D; type_right = ELEMENTWISE_1D;
else else
type_right = ELEMENTWISE_2D; type_right = ELEMENTWISE_2D;
@@ -160,7 +160,7 @@ namespace isaac
std::vector<std::shared_ptr<array> > temporaries_; std::vector<std::shared_ptr<array> > temporaries_;
expression_type final_type; expression_type final_type;
//GEMM //MATRIX_PRODUCT
if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){ if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){
final_type = args.type; final_type = args.type;
} }
@@ -208,10 +208,10 @@ namespace isaac
} }
temporaries_.push_back(tmp); temporaries_.push_back(tmp);
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE; tree[rootidx].op.type = ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp); fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *it->second; tree[rootidx].rhs = *it->second;
tree[rootidx].rhs.type_family = it->second->type_family; tree[rootidx].rhs.subtype = it->second->subtype;
//Execute //Execute
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options())); profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));

View File

@@ -10,22 +10,19 @@ namespace isaac
void fill(lhs_rhs_element &x, invalid_node) void fill(lhs_rhs_element &x, invalid_node)
{ {
x.type_family = INVALID_TYPE_FAMILY;
x.subtype = INVALID_SUBTYPE; x.subtype = INVALID_SUBTYPE;
x.dtype = INVALID_NUMERIC_TYPE; x.dtype = INVALID_NUMERIC_TYPE;
} }
void fill(lhs_rhs_element & x, std::size_t node_index) void fill(lhs_rhs_element & x, std::size_t node_index)
{ {
x.type_family = COMPOSITE_OPERATOR_FAMILY; x.subtype = COMPOSITE_OPERATOR_TYPE;
x.subtype = INVALID_SUBTYPE;
x.dtype = INVALID_NUMERIC_TYPE; x.dtype = INVALID_NUMERIC_TYPE;
x.node_index = node_index; x.node_index = node_index;
} }
void fill(lhs_rhs_element & x, for_idx_t index) void fill(lhs_rhs_element & x, for_idx_t index)
{ {
x.type_family = PLACEHOLDER_TYPE_FAMILY;
x.subtype = FOR_LOOP_INDEX_TYPE; x.subtype = FOR_LOOP_INDEX_TYPE;
x.dtype = INVALID_NUMERIC_TYPE; x.dtype = INVALID_NUMERIC_TYPE;
x.for_idx = index; x.for_idx = index;
@@ -33,7 +30,6 @@ void fill(lhs_rhs_element & x, for_idx_t index)
void fill(lhs_rhs_element & x, array_base const & a) void fill(lhs_rhs_element & x, array_base const & a)
{ {
x.type_family = ARRAY_TYPE_FAMILY;
x.subtype = DENSE_ARRAY_TYPE; x.subtype = DENSE_ARRAY_TYPE;
x.dtype = a.dtype(); x.dtype = a.dtype();
x.array = (array_base*)&a; x.array = (array_base*)&a;
@@ -41,7 +37,6 @@ void fill(lhs_rhs_element & x, array_base const & a)
void fill(lhs_rhs_element & x, value_scalar const & v) void fill(lhs_rhs_element & x, value_scalar const & v)
{ {
x.type_family = VALUE_TYPE_FAMILY;
x.dtype = v.dtype(); x.dtype = v.dtype();
x.subtype = VALUE_SCALAR_TYPE; x.subtype = VALUE_SCALAR_TYPE;
x.vscalar = v.values(); x.vscalar = v.values();
@@ -51,7 +46,7 @@ lhs_rhs_element::lhs_rhs_element(){}
// //
op_element::op_element() {} op_element::op_element() {}
op_element::op_element(operation_node_type_family const & _type_family, operation_node_type const & _type) : type_family(_type_family), type(_type){} op_element::op_element(operation_type_family const & _type_family, operation_type const & _type) : type_family(_type_family), type(_type){}
// //
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op) math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
@@ -122,8 +117,8 @@ math_expression::math_expression(math_expression const & lhs, math_expression co
tree_[root_].op = op; tree_[root_].op = op;
fill(tree_[root_].rhs, lsize + rhs.root_); fill(tree_[root_].rhs, lsize + rhs.root_);
for(container_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it){ for(container_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it){
if(it->lhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->lhs.node_index+=lsize; if(it->lhs.subtype==COMPOSITE_OPERATOR_TYPE) it->lhs.node_index+=lsize;
if(it->rhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->rhs.node_index+=lsize; if(it->rhs.subtype==COMPOSITE_OPERATOR_TYPE) it->rhs.node_index+=lsize;
} }
root_ = tree_.size() - 1; root_ = tree_.size() - 1;
} }
@@ -181,17 +176,17 @@ int_t math_expression::dim() const
//} //}
math_expression math_expression::operator-() math_expression math_expression::operator-()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), *context_, dtype_, shape_); } { return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
math_expression math_expression::operator!() math_expression math_expression::operator!()
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), *context_, INT_TYPE, shape_); } { return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
// //
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init) math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
{ {
math_expression::node const * current = &init; math_expression::node const * current = &init;
while (current->lhs.type_family==COMPOSITE_OPERATOR_FAMILY) while (current->lhs.subtype==COMPOSITE_OPERATOR_TYPE)
current = &array[current->lhs.node_index]; current = &array[current->lhs.node_index];
return *current; return *current;
} }
@@ -200,8 +195,8 @@ math_expression::node const & lhs_most(math_expression::container_type const & a
{ return lhs_most(array, array[root]); } { return lhs_most(array, array[root]); }
// //
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.dtype()); } math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); } math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; } math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; } math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }

View File

@@ -10,7 +10,7 @@ namespace isaac
#define ISAAC_MAP_TO_STRING(NAME) case NAME: return #NAME #define ISAAC_MAP_TO_STRING(NAME) case NAME: return #NAME
inline std::string to_string(math_expression_node_subtype const & f) inline std::string to_string(node_type const & f)
{ {
switch(f) switch(f)
{ {
@@ -23,7 +23,7 @@ inline std::string to_string(math_expression_node_subtype const & f)
inline std::string to_string(lhs_rhs_element const & e) inline std::string to_string(lhs_rhs_element const & e)
{ {
if(e.type_family==COMPOSITE_OPERATOR_FAMILY) if(e.subtype==COMPOSITE_OPERATOR_TYPE)
{ {
return"COMPOSITE [" + tools::to_string(e.node_index) + "]"; return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
} }
@@ -53,10 +53,10 @@ namespace detail
os << "Node " << node_index << ": " << current_node << std::endl; os << "Node " << node_index << ": " << current_node << std::endl;
if (current_node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY) if (current_node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
print_node(os, s, current_node.lhs.node_index, indent+1); print_node(os, s, current_node.lhs.node_index, indent+1);
if (current_node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY) if (current_node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
print_node(os, s, current_node.rhs.node_index, indent+1); print_node(os, s, current_node.rhs.node_index, indent+1);
} }
} }

View File

@@ -12,33 +12,33 @@ namespace preset
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a) void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
{ {
//Matrix-Matrix product node //Matrix-Matrix product node
if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY) if(tree[rootidx].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
{ {
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs; if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE) a.A = &tree[rootidx].lhs;
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs; if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE) a.B = &tree[rootidx].rhs;
switch(tree[rootidx].op.type) switch(tree[rootidx].op.type)
{ {
case OPERATOR_GEMM_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break; case MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
case OPERATOR_GEMM_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break; case MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
case OPERATOR_GEMM_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break; case MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
case OPERATOR_GEMM_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break; case MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
default: break; default: break;
} }
} }
//Scalar multiplication node //Scalar multiplication node
if(tree[rootidx].op.type==OPERATOR_MULT_TYPE) if(tree[rootidx].op.type==MULT_TYPE)
{ {
//alpha*PROD //alpha*PROD
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY) && tree[tree[rootidx].rhs.node_index].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
{ {
a.alpha = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype); a.alpha = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
handle_node(tree, tree[rootidx].rhs.node_index, a); handle_node(tree, tree[rootidx].rhs.node_index, a);
} }
//beta*C //beta*C
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
{ {
a.beta = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype); a.beta = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
a.C = &tree[rootidx].rhs; a.C = &tree[rootidx].rhs;
@@ -55,26 +55,26 @@ matrix_product::args matrix_product::check(math_expression::container_type const
return result; return result;
result.alpha = value_scalar(1, dtype); result.alpha = value_scalar(1, dtype);
result.beta = value_scalar(0, dtype); result.beta = value_scalar(0, dtype);
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
{ {
rootidx = tree[rootidx].rhs.node_index; rootidx = tree[rootidx].rhs.node_index;
bool is_add = tree[rootidx].op.type==OPERATOR_ADD_TYPE; bool is_add = tree[rootidx].op.type==ADD_TYPE;
bool is_sub = tree[rootidx].op.type==OPERATOR_SUB_TYPE; bool is_sub = tree[rootidx].op.type==SUB_TYPE;
//Form X +- Y" //Form X +- Y"
if(is_add || is_sub) if(is_add || is_sub)
{ {
if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY) if(tree[rootidx].lhs.subtype==COMPOSITE_OPERATOR_TYPE)
handle_node(tree, tree[rootidx].lhs.node_index, result); handle_node(tree, tree[rootidx].lhs.node_index, result);
else if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) else if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE)
{ {
result.C = &tree[rootidx].lhs; result.C = &tree[rootidx].lhs;
result.beta = value_scalar(1, dtype); result.beta = value_scalar(1, dtype);
result.alpha = value_scalar(is_add?1:-1, dtype); result.alpha = value_scalar(is_add?1:-1, dtype);
} }
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY) if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
handle_node(tree, tree[rootidx].rhs.node_index, result); handle_node(tree, tree[rootidx].rhs.node_index, result);
else if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) else if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
{ {
result.C = &tree[rootidx].rhs; result.C = &tree[rootidx].rhs;
result.alpha = value_scalar(1, dtype); result.alpha = value_scalar(1, dtype);

View File

@@ -169,7 +169,7 @@ extern "C"
//BLAS3 //BLAS3
//***************** //*****************
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \ #define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\ clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
size_t M, size_t N, size_t K,\ size_t M, size_t N, size_t K,\
TYPE_CL alpha, const cl_mem cmA, size_t offA, size_t lda,\ TYPE_CL alpha, const cl_mem cmA, size_t offA, size_t lda,\
@@ -208,8 +208,8 @@ extern "C"
return clblasSuccess;\ return clblasSuccess;\
} }
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float) MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double) MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
#undef DOT #undef DOT

View File

@@ -194,7 +194,7 @@ extern "C"
//BLAS3 //BLAS3
//***************** //*****************
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \ #define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \
void cublas ## TYPE_CHAR ## gemm (char transa, char transb, int m, int n, int k,\ void cublas ## TYPE_CHAR ## gemm (char transa, char transb, int m, int n, int k,\
TYPE_CU alpha, const TYPE_CU *A, int lda,\ TYPE_CU alpha, const TYPE_CU *A, int lda,\
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\ const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
@@ -235,6 +235,6 @@ extern "C"
return CUBLAS_STATUS_SUCCESS;\ return CUBLAS_STATUS_SUCCESS;\
} }
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float) MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double) MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
} }