Code Quality: heavy renaming and cleaning
This commit is contained in:
12
include/external/clBLAS.h
vendored
12
include/external/clBLAS.h
vendored
@@ -44,7 +44,7 @@ extern "C" {
|
|||||||
* SWAP, SCAL, COPY, AXPY, DOT, DOTU, DOTC, ROTG, ROTMG, ROT, ROTM, iAMAX, ASUM and NRM2,
|
* SWAP, SCAL, COPY, AXPY, DOT, DOTU, DOTC, ROTG, ROTMG, ROT, ROTM, iAMAX, ASUM and NRM2,
|
||||||
* BLAS-2 functions GEMV, SYMV, TRMV, TRSV, HEMV, SYR, SYR2, HER, HER2, GER, GERU, GERC,
|
* BLAS-2 functions GEMV, SYMV, TRMV, TRSV, HEMV, SYR, SYR2, HER, HER2, GER, GERU, GERC,
|
||||||
* TPMV, SPMV, HPMV, TPSV, SPR, SPR2, HPR, HPR2, GBMV, TBMV, SBMV, HBMV and TBSV
|
* TPMV, SPMV, HPMV, TPSV, SPR, SPR2, HPR, HPR2, GBMV, TBMV, SBMV, HBMV and TBSV
|
||||||
* and BLAS-3 functions GEMM, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K.
|
* and BLAS-3 functions MATRIX_PRODUCT, SYMM, TRMM, TRSM, HEMM, HERK, HER2K, SYRK and SYR2K.
|
||||||
*
|
*
|
||||||
* This library’s primary goal is to assist the end user to enqueue OpenCL
|
* This library’s primary goal is to assist the end user to enqueue OpenCL
|
||||||
* kernels to process BLAS functions in an OpenCL-efficient manner, while
|
* kernels to process BLAS functions in an OpenCL-efficient manner, while
|
||||||
@@ -7314,7 +7314,7 @@ clblasZtbsv(
|
|||||||
/*@}*/
|
/*@}*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @defgroup GEMM GEMM - General matrix-matrix multiplication
|
* @defgroup MATRIX_PRODUCT MATRIX_PRODUCT - General matrix-matrix multiplication
|
||||||
* @ingroup BLAS3
|
* @ingroup BLAS3
|
||||||
*/
|
*/
|
||||||
/*@{*/
|
/*@{*/
|
||||||
@@ -7372,7 +7372,7 @@ clblasZtbsv(
|
|||||||
* the size of the respective buffer object;
|
* the size of the respective buffer object;
|
||||||
* - the same error codes as clblasSgemm() otherwise.
|
* - the same error codes as clblasSgemm() otherwise.
|
||||||
*
|
*
|
||||||
* @ingroup GEMM
|
* @ingroup MATRIX_PRODUCT
|
||||||
*/
|
*/
|
||||||
clblasStatus
|
clblasStatus
|
||||||
clblasSgemm(
|
clblasSgemm(
|
||||||
@@ -7453,7 +7453,7 @@ clblasSgemm(
|
|||||||
* the size of the respective buffer object;
|
* the size of the respective buffer object;
|
||||||
* - the same error codes as the clblasSgemm() function otherwise.
|
* - the same error codes as the clblasSgemm() function otherwise.
|
||||||
*
|
*
|
||||||
* @ingroup GEMM
|
* @ingroup MATRIX_PRODUCT
|
||||||
*/
|
*/
|
||||||
clblasStatus
|
clblasStatus
|
||||||
clblasDgemm(
|
clblasDgemm(
|
||||||
@@ -7527,7 +7527,7 @@ clblasDgemm(
|
|||||||
* the size of the respective buffer object;
|
* the size of the respective buffer object;
|
||||||
* - the same error codes as the clblasSgemm() function otherwise.
|
* - the same error codes as the clblasSgemm() function otherwise.
|
||||||
*
|
*
|
||||||
* @ingroup GEMM
|
* @ingroup MATRIX_PRODUCT
|
||||||
*/
|
*/
|
||||||
clblasStatus
|
clblasStatus
|
||||||
clblasCgemm(
|
clblasCgemm(
|
||||||
@@ -7603,7 +7603,7 @@ clblasCgemm(
|
|||||||
* the size of the respective buffer object;
|
* the size of the respective buffer object;
|
||||||
* - the same error codes as the clblasSgemm() function otherwise.
|
* - the same error codes as the clblasSgemm() function otherwise.
|
||||||
*
|
*
|
||||||
* @ingroup GEMM
|
* @ingroup MATRIX_PRODUCT
|
||||||
*/
|
*/
|
||||||
clblasStatus
|
clblasStatus
|
||||||
clblasZgemm(
|
clblasZgemm(
|
||||||
|
2
include/external/cuda/cublas.h
vendored
2
include/external/cuda/cublas.h
vendored
@@ -418,7 +418,7 @@ void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha,
|
|||||||
const cuDoubleComplex *x, int incx, const cuDoubleComplex *y,
|
const cuDoubleComplex *x, int incx, const cuDoubleComplex *y,
|
||||||
int incy, cuDoubleComplex *AP);
|
int incy, cuDoubleComplex *AP);
|
||||||
/* ------------------------BLAS3 Functions ------------------------------- */
|
/* ------------------------BLAS3 Functions ------------------------------- */
|
||||||
/* GEMM */
|
/* MATRIX_PRODUCT */
|
||||||
void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k,
|
void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k,
|
||||||
float alpha, const float *A, int lda,
|
float alpha, const float *A, int lda,
|
||||||
const float *B, int ldb, float beta, float *C,
|
const float *B, int ldb, float beta, float *C,
|
||||||
|
4
include/external/cuda/cublas_api.h
vendored
4
include/external/cuda/cublas_api.h
vendored
@@ -1508,7 +1508,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle,
|
|||||||
|
|
||||||
/* ---------------- CUBLAS BLAS3 functions ---------------- */
|
/* ---------------- CUBLAS BLAS3 functions ---------------- */
|
||||||
|
|
||||||
/* GEMM */
|
/* MATRIX_PRODUCT */
|
||||||
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle,
|
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle,
|
||||||
cublasOperation_t transa,
|
cublasOperation_t transa,
|
||||||
cublasOperation_t transb,
|
cublasOperation_t transb,
|
||||||
@@ -2042,7 +2042,7 @@ CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cubl
|
|||||||
int ldb,
|
int ldb,
|
||||||
cuDoubleComplex *C,
|
cuDoubleComplex *C,
|
||||||
int ldc);
|
int ldc);
|
||||||
/* BATCH GEMM */
|
/* BATCH MATRIX_PRODUCT */
|
||||||
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle,
|
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle,
|
||||||
cublasOperation_t transa,
|
cublasOperation_t transa,
|
||||||
cublasOperation_t transb,
|
cublasOperation_t transb,
|
||||||
|
@@ -248,9 +248,9 @@ public:
|
|||||||
|
|
||||||
class mapped_cast : public mapped_object
|
class mapped_cast : public mapped_object
|
||||||
{
|
{
|
||||||
static std::string operator_to_str(operation_node_type type);
|
static std::string operator_to_str(operation_type type);
|
||||||
public:
|
public:
|
||||||
mapped_cast(operation_node_type type, unsigned int id);
|
mapped_cast(operation_type type, unsigned int id);
|
||||||
};
|
};
|
||||||
|
|
||||||
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);
|
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);
|
||||||
|
@@ -47,9 +47,9 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
|
|||||||
//Lhs:
|
//Lhs:
|
||||||
if (recurse)
|
if (recurse)
|
||||||
{
|
{
|
||||||
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.lhs.node_index, fun, inspect);
|
traverse(math_expression, root_node.lhs.node_index, fun, inspect);
|
||||||
if (root_node.lhs.type_family != INVALID_TYPE_FAMILY)
|
if (root_node.lhs.subtype != INVALID_SUBTYPE)
|
||||||
fun(math_expression, root_idx, LHS_NODE_TYPE);
|
fun(math_expression, root_idx, LHS_NODE_TYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,11 +58,11 @@ inline void traverse(isaac::math_expression const & math_expression, std::size_t
|
|||||||
fun(math_expression, root_idx, PARENT_NODE_TYPE);
|
fun(math_expression, root_idx, PARENT_NODE_TYPE);
|
||||||
|
|
||||||
//Rhs:
|
//Rhs:
|
||||||
if (recurse && root_node.rhs.type_family!=INVALID_TYPE_FAMILY)
|
if (recurse && root_node.rhs.subtype!=INVALID_SUBTYPE)
|
||||||
{
|
{
|
||||||
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.rhs.node_index, fun, inspect);
|
traverse(math_expression, root_node.rhs.node_index, fun, inspect);
|
||||||
if (root_node.rhs.type_family != INVALID_TYPE_FAMILY)
|
if (root_node.rhs.subtype != INVALID_SUBTYPE)
|
||||||
fun(math_expression, root_idx, RHS_NODE_TYPE);
|
fun(math_expression, root_idx, RHS_NODE_TYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,10 +84,10 @@ private:
|
|||||||
class filter_elements_fun : public traversal_functor
|
class filter_elements_fun : public traversal_functor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out);
|
filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out);
|
||||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
|
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
|
||||||
private:
|
private:
|
||||||
math_expression_node_subtype subtype_;
|
node_type subtype_;
|
||||||
std::vector<lhs_rhs_element> & out_;
|
std::vector<lhs_rhs_element> & out_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -96,9 +96,9 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
|
|||||||
size_t root,
|
size_t root,
|
||||||
bool inspect);
|
bool inspect);
|
||||||
|
|
||||||
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype,
|
std::vector<lhs_rhs_element> filter_elements(node_type subtype,
|
||||||
isaac::math_expression const & math_expression);
|
isaac::math_expression const & math_expression);
|
||||||
const char * evaluate(operation_node_type type);
|
const char * evaluate(operation_type type);
|
||||||
|
|
||||||
/** @brief functor for generating the expression string from a math_expression */
|
/** @brief functor for generating the expression string from a math_expression */
|
||||||
class evaluate_expression_traversal: public traversal_functor
|
class evaluate_expression_traversal: public traversal_functor
|
||||||
|
@@ -23,140 +23,121 @@ namespace isaac
|
|||||||
class array_base;
|
class array_base;
|
||||||
|
|
||||||
/** @brief Optimization enum for grouping operations into unary or binary operations. Just for optimization of lookups. */
|
/** @brief Optimization enum for grouping operations into unary or binary operations. Just for optimization of lookups. */
|
||||||
enum operation_node_type_family
|
enum operation_type_family
|
||||||
{
|
{
|
||||||
OPERATOR_INVALID_TYPE_FAMILY = 0,
|
INVALID_TYPE_FAMILY = 0,
|
||||||
|
|
||||||
// BLAS1-type
|
// BLAS1-type
|
||||||
OPERATOR_UNARY_TYPE_FAMILY,
|
UNARY_TYPE_FAMILY,
|
||||||
OPERATOR_BINARY_TYPE_FAMILY,
|
BINARY_TYPE_FAMILY,
|
||||||
OPERATOR_VECTOR_DOT_TYPE_FAMILY,
|
VECTOR_DOT_TYPE_FAMILY,
|
||||||
|
|
||||||
// BLAS2-type
|
// BLAS2-type
|
||||||
OPERATOR_ROWS_DOT_TYPE_FAMILY,
|
ROWS_DOT_TYPE_FAMILY,
|
||||||
OPERATOR_COLUMNS_DOT_TYPE_FAMILY,
|
COLUMNS_DOT_TYPE_FAMILY,
|
||||||
|
|
||||||
// BLAS3-type
|
// BLAS3-type
|
||||||
OPERATOR_GEMM_TYPE_FAMILY
|
MATRIX_PRODUCT_TYPE_FAMILY
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Enumeration for identifying the possible operations */
|
/** @brief Enumeration for identifying the possible operations */
|
||||||
enum operation_node_type
|
enum operation_type
|
||||||
{
|
{
|
||||||
OPERATOR_INVALID_TYPE = 0,
|
INVALID_TYPE = 0,
|
||||||
|
|
||||||
// unary operator
|
// unary operator
|
||||||
OPERATOR_MINUS_TYPE,
|
MINUS_TYPE,
|
||||||
OPERATOR_NEGATE_TYPE,
|
NEGATE_TYPE,
|
||||||
|
|
||||||
// unary expression
|
// unary expression
|
||||||
OPERATOR_CAST_BOOL_TYPE,
|
CAST_BOOL_TYPE,
|
||||||
OPERATOR_CAST_CHAR_TYPE,
|
CAST_CHAR_TYPE,
|
||||||
OPERATOR_CAST_UCHAR_TYPE,
|
CAST_UCHAR_TYPE,
|
||||||
OPERATOR_CAST_SHORT_TYPE,
|
CAST_SHORT_TYPE,
|
||||||
OPERATOR_CAST_USHORT_TYPE,
|
CAST_USHORT_TYPE,
|
||||||
OPERATOR_CAST_INT_TYPE,
|
CAST_INT_TYPE,
|
||||||
OPERATOR_CAST_UINT_TYPE,
|
CAST_UINT_TYPE,
|
||||||
OPERATOR_CAST_LONG_TYPE,
|
CAST_LONG_TYPE,
|
||||||
OPERATOR_CAST_ULONG_TYPE,
|
CAST_ULONG_TYPE,
|
||||||
OPERATOR_CAST_HALF_TYPE,
|
CAST_HALF_TYPE,
|
||||||
OPERATOR_CAST_FLOAT_TYPE,
|
CAST_FLOAT_TYPE,
|
||||||
OPERATOR_CAST_DOUBLE_TYPE,
|
CAST_DOUBLE_TYPE,
|
||||||
|
|
||||||
OPERATOR_ABS_TYPE,
|
ABS_TYPE,
|
||||||
OPERATOR_ACOS_TYPE,
|
ACOS_TYPE,
|
||||||
OPERATOR_ASIN_TYPE,
|
ASIN_TYPE,
|
||||||
OPERATOR_ATAN_TYPE,
|
ATAN_TYPE,
|
||||||
OPERATOR_CEIL_TYPE,
|
CEIL_TYPE,
|
||||||
OPERATOR_COS_TYPE,
|
COS_TYPE,
|
||||||
OPERATOR_COSH_TYPE,
|
COSH_TYPE,
|
||||||
OPERATOR_EXP_TYPE,
|
EXP_TYPE,
|
||||||
OPERATOR_FABS_TYPE,
|
FABS_TYPE,
|
||||||
OPERATOR_FLOOR_TYPE,
|
FLOOR_TYPE,
|
||||||
OPERATOR_LOG_TYPE,
|
LOG_TYPE,
|
||||||
OPERATOR_LOG10_TYPE,
|
LOG10_TYPE,
|
||||||
OPERATOR_SIN_TYPE,
|
SIN_TYPE,
|
||||||
OPERATOR_SINH_TYPE,
|
SINH_TYPE,
|
||||||
OPERATOR_SQRT_TYPE,
|
SQRT_TYPE,
|
||||||
OPERATOR_TAN_TYPE,
|
TAN_TYPE,
|
||||||
OPERATOR_TANH_TYPE,
|
TANH_TYPE,
|
||||||
OPERATOR_TRANS_TYPE,
|
TRANS_TYPE,
|
||||||
|
|
||||||
// binary expression
|
// binary expression
|
||||||
OPERATOR_ASSIGN_TYPE,
|
ASSIGN_TYPE,
|
||||||
OPERATOR_INPLACE_ADD_TYPE,
|
INPLACE_ADD_TYPE,
|
||||||
OPERATOR_INPLACE_SUB_TYPE,
|
INPLACE_SUB_TYPE,
|
||||||
OPERATOR_ADD_TYPE,
|
ADD_TYPE,
|
||||||
OPERATOR_SUB_TYPE,
|
SUB_TYPE,
|
||||||
OPERATOR_MULT_TYPE,
|
MULT_TYPE,
|
||||||
OPERATOR_DIV_TYPE,
|
DIV_TYPE,
|
||||||
OPERATOR_ELEMENT_ARGFMAX_TYPE,
|
ELEMENT_ARGFMAX_TYPE,
|
||||||
OPERATOR_ELEMENT_ARGFMIN_TYPE,
|
ELEMENT_ARGFMIN_TYPE,
|
||||||
OPERATOR_ELEMENT_ARGMAX_TYPE,
|
ELEMENT_ARGMAX_TYPE,
|
||||||
OPERATOR_ELEMENT_ARGMIN_TYPE,
|
ELEMENT_ARGMIN_TYPE,
|
||||||
OPERATOR_ELEMENT_PROD_TYPE,
|
ELEMENT_PROD_TYPE,
|
||||||
OPERATOR_ELEMENT_DIV_TYPE,
|
ELEMENT_DIV_TYPE,
|
||||||
OPERATOR_ELEMENT_EQ_TYPE,
|
ELEMENT_EQ_TYPE,
|
||||||
OPERATOR_ELEMENT_NEQ_TYPE,
|
ELEMENT_NEQ_TYPE,
|
||||||
OPERATOR_ELEMENT_GREATER_TYPE,
|
ELEMENT_GREATER_TYPE,
|
||||||
OPERATOR_ELEMENT_GEQ_TYPE,
|
ELEMENT_GEQ_TYPE,
|
||||||
OPERATOR_ELEMENT_LESS_TYPE,
|
ELEMENT_LESS_TYPE,
|
||||||
OPERATOR_ELEMENT_LEQ_TYPE,
|
ELEMENT_LEQ_TYPE,
|
||||||
OPERATOR_ELEMENT_POW_TYPE,
|
ELEMENT_POW_TYPE,
|
||||||
OPERATOR_ELEMENT_FMAX_TYPE,
|
ELEMENT_FMAX_TYPE,
|
||||||
OPERATOR_ELEMENT_FMIN_TYPE,
|
ELEMENT_FMIN_TYPE,
|
||||||
OPERATOR_ELEMENT_MAX_TYPE,
|
ELEMENT_MAX_TYPE,
|
||||||
OPERATOR_ELEMENT_MIN_TYPE,
|
ELEMENT_MIN_TYPE,
|
||||||
|
|
||||||
//Products
|
//Products
|
||||||
OPERATOR_OUTER_PROD_TYPE,
|
OUTER_PROD_TYPE,
|
||||||
OPERATOR_GEMM_NN_TYPE,
|
MATRIX_PRODUCT_NN_TYPE,
|
||||||
OPERATOR_GEMM_TN_TYPE,
|
MATRIX_PRODUCT_TN_TYPE,
|
||||||
OPERATOR_GEMM_NT_TYPE,
|
MATRIX_PRODUCT_NT_TYPE,
|
||||||
OPERATOR_GEMM_TT_TYPE,
|
MATRIX_PRODUCT_TT_TYPE,
|
||||||
|
|
||||||
//Access modifiers
|
//Access modifiers
|
||||||
OPERATOR_MATRIX_DIAG_TYPE,
|
MATRIX_DIAG_TYPE,
|
||||||
OPERATOR_MATRIX_ROW_TYPE,
|
MATRIX_ROW_TYPE,
|
||||||
OPERATOR_MATRIX_COLUMN_TYPE,
|
MATRIX_COLUMN_TYPE,
|
||||||
OPERATOR_REPEAT_TYPE,
|
REPEAT_TYPE,
|
||||||
OPERATOR_RESHAPE_TYPE,
|
RESHAPE_TYPE,
|
||||||
OPERATOR_SHIFT_TYPE,
|
SHIFT_TYPE,
|
||||||
OPERATOR_VDIAG_TYPE,
|
VDIAG_TYPE,
|
||||||
OPERATOR_ACCESS_INDEX_TYPE,
|
ACCESS_INDEX_TYPE,
|
||||||
|
|
||||||
|
|
||||||
OPERATOR_PAIR_TYPE,
|
PAIR_TYPE,
|
||||||
|
|
||||||
OPERATOR_FUSE,
|
OPERATOR_FUSE,
|
||||||
OPERATOR_SFOR_TYPE,
|
SFOR_TYPE,
|
||||||
};
|
|
||||||
|
|
||||||
/** @brief Groups the type of a node in the math_expression tree. Used for faster dispatching */
|
|
||||||
enum math_expression_node_type_family
|
|
||||||
{
|
|
||||||
INVALID_TYPE_FAMILY = 0,
|
|
||||||
COMPOSITE_OPERATOR_FAMILY,
|
|
||||||
VALUE_TYPE_FAMILY,
|
|
||||||
ARRAY_TYPE_FAMILY,
|
|
||||||
PLACEHOLDER_TYPE_FAMILY
|
|
||||||
};
|
|
||||||
|
|
||||||
/** @brief Encodes the type of a node in the math_expression tree. */
|
|
||||||
enum math_expression_node_subtype
|
|
||||||
{
|
|
||||||
INVALID_SUBTYPE = 0,
|
|
||||||
VALUE_SCALAR_TYPE,
|
|
||||||
DENSE_ARRAY_TYPE,
|
|
||||||
FOR_LOOP_INDEX_TYPE
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct op_element
|
struct op_element
|
||||||
{
|
{
|
||||||
op_element();
|
op_element();
|
||||||
op_element(operation_node_type_family const & _type_family, operation_node_type const & _type);
|
op_element(operation_type_family const & _type_family, operation_type const & _type);
|
||||||
operation_node_type_family type_family;
|
operation_type_family type_family;
|
||||||
operation_node_type type;
|
operation_type type;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct for_idx_t
|
struct for_idx_t
|
||||||
@@ -172,11 +153,19 @@ struct for_idx_t
|
|||||||
int level;
|
int level;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum node_type
|
||||||
|
{
|
||||||
|
INVALID_SUBTYPE = 0,
|
||||||
|
COMPOSITE_OPERATOR_TYPE,
|
||||||
|
VALUE_SCALAR_TYPE,
|
||||||
|
DENSE_ARRAY_TYPE,
|
||||||
|
FOR_LOOP_INDEX_TYPE
|
||||||
|
};
|
||||||
|
|
||||||
struct lhs_rhs_element
|
struct lhs_rhs_element
|
||||||
{
|
{
|
||||||
lhs_rhs_element();
|
lhs_rhs_element();
|
||||||
math_expression_node_type_family type_family;
|
node_type subtype;
|
||||||
math_expression_node_subtype subtype;
|
|
||||||
numeric_type dtype;
|
numeric_type dtype;
|
||||||
union
|
union
|
||||||
{
|
{
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
|
||||||
std::string to_string(math_expression_node_subtype const & f);
|
std::string to_string(node_type const & f);
|
||||||
std::string to_string(lhs_rhs_element const & e);
|
std::string to_string(lhs_rhs_element const & e);
|
||||||
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
|
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
|
||||||
std::string to_string(isaac::math_expression const & s);
|
std::string to_string(isaac::math_expression const & s);
|
||||||
|
@@ -19,13 +19,13 @@ ISAACAPI typename std::conditional<std::is_arithmetic<T>::value, value_scalar, T
|
|||||||
|
|
||||||
template<typename T, typename... Args>
|
template<typename T, typename... Args>
|
||||||
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
|
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
|
||||||
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_PAIR_TYPE), context, numeric_type_of(x), {1}); }
|
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
|
||||||
|
|
||||||
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
|
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
|
||||||
{
|
{
|
||||||
for(unsigned int i = 0 ; i < idx ; ++i){
|
for(unsigned int i = 0 ; i < idx ; ++i){
|
||||||
math_expression::node node = tree[root];
|
math_expression::node node = tree[root];
|
||||||
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
root = node.rhs.node_index;
|
root = node.rhs.node_index;
|
||||||
else
|
else
|
||||||
return value_scalar(node.rhs.vscalar, node.rhs.dtype);
|
return value_scalar(node.rhs.vscalar, node.rhs.dtype);
|
||||||
|
252
lib/array.cpp
252
lib/array.cpp
@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
|
|||||||
{
|
{
|
||||||
if(shape_.min()==0) return *this;
|
if(shape_.min()==0) return *this;
|
||||||
assert(dtype_ == rhs.dtype());
|
assert(dtype_ == rhs.dtype());
|
||||||
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||||
execute(execution_handler(expression));
|
execute(execution_handler(expression));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
|
|||||||
{
|
{
|
||||||
if(shape_.min()==0) return *this;
|
if(shape_.min()==0) return *this;
|
||||||
assert(dtype_ == rhs.dtype());
|
assert(dtype_ == rhs.dtype());
|
||||||
math_expression expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||||
execute(execution_handler(expression));
|
execute(execution_handler(expression));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@@ -188,7 +188,7 @@ array_base& array_base::operator=(execution_handler const & c)
|
|||||||
{
|
{
|
||||||
if(shape_.min()==0) return *this;
|
if(shape_.min()==0) return *this;
|
||||||
assert(dtype_ == c.x().dtype());
|
assert(dtype_ == c.x().dtype());
|
||||||
math_expression expression(*this, c.x(), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE), context_, dtype_, shape_);
|
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||||
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@@ -228,53 +228,53 @@ INSTANTIATE(double);
|
|||||||
|
|
||||||
|
|
||||||
math_expression array_base::operator-()
|
math_expression array_base::operator-()
|
||||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
math_expression array_base::operator!()
|
math_expression array_base::operator!()
|
||||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||||
|
|
||||||
//
|
//
|
||||||
array_base & array_base::operator+=(value_scalar const & rhs)
|
array_base & array_base::operator+=(value_scalar const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator+=(array_base const & rhs)
|
array_base & array_base::operator+=(array_base const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator+=(math_expression const & rhs)
|
array_base & array_base::operator+=(math_expression const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), rhs.context(), dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||||
//----
|
//----
|
||||||
array_base & array_base::operator-=(value_scalar const & rhs)
|
array_base & array_base::operator-=(value_scalar const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator-=(array_base const & rhs)
|
array_base & array_base::operator-=(array_base const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator-=(math_expression const & rhs)
|
array_base & array_base::operator-=(math_expression const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), rhs.context(), dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||||
//----
|
//----
|
||||||
array_base & array_base::operator*=(value_scalar const & rhs)
|
array_base & array_base::operator*=(value_scalar const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator*=(array_base const & rhs)
|
array_base & array_base::operator*=(array_base const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator*=(math_expression const & rhs)
|
array_base & array_base::operator*=(math_expression const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), rhs.context(), dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||||
//----
|
//----
|
||||||
array_base & array_base::operator/=(value_scalar const & rhs)
|
array_base & array_base::operator/=(value_scalar const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator/=(array_base const & rhs)
|
array_base & array_base::operator/=(array_base const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||||
|
|
||||||
array_base & array_base::operator/=(math_expression const & rhs)
|
array_base & array_base::operator/=(math_expression const & rhs)
|
||||||
{ return *this = math_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), rhs.context(), dtype_, shape_); }
|
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||||
|
|
||||||
/*--- Indexing operators -----*/
|
/*--- Indexing operators -----*/
|
||||||
//---------------------------------------
|
//---------------------------------------
|
||||||
math_expression array_base::operator[](for_idx_t idx) const
|
math_expression array_base::operator[](for_idx_t idx) const
|
||||||
{
|
{
|
||||||
return math_expression(*this, idx, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
||||||
}
|
}
|
||||||
|
|
||||||
scalar array_base::operator [](int_t idx)
|
scalar array_base::operator [](int_t idx)
|
||||||
@@ -512,72 +512,72 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
|
|||||||
|
|
||||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
||||||
math_expression OPNAME (array_base const & x, math_expression const & y) \
|
math_expression OPNAME (array_base const & x, math_expression const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (array_base const & x, array_base const & y) \
|
math_expression OPNAME (array_base const & x, array_base const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (array_base const & x, value_scalar const & y) \
|
math_expression OPNAME (array_base const & x, value_scalar const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
|
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||||
\
|
\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (math_expression const & x, math_expression const & y) \
|
math_expression OPNAME (math_expression const & x, math_expression const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (math_expression const & x, array_base const & y) \
|
math_expression OPNAME (math_expression const & x, array_base const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
|
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
|
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||||
\
|
\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
|
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (value_scalar const & y, array_base const & x) \
|
math_expression OPNAME (value_scalar const & y, array_base const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
|
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||||
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||||
\
|
\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
|
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
|
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||||
\
|
\
|
||||||
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
|
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||||
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); }
|
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
|
||||||
|
|
||||||
|
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator +, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator -, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(SUB_TYPE, operator -, x.dtype())
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator *, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(MULT_TYPE, operator *, x.dtype())
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator /, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(DIV_TYPE, operator /, x.dtype())
|
||||||
|
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MAX_TYPE, maximum, x.dtype())
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_MIN_TYPE, minimum, x.dtype())
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_POW_TYPE, pow, x.dtype())
|
||||||
|
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype())
|
DEFINE_ELEMENT_BINARY_OPERATOR(ASSIGN_TYPE, assign, x.dtype())
|
||||||
|
|
||||||
|
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GREATER_TYPE, operator >, INT_TYPE)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_GEQ_TYPE, operator >=, INT_TYPE)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator <, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LESS_TYPE, operator <, INT_TYPE)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_LEQ_TYPE, operator <=, INT_TYPE)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
||||||
|
|
||||||
#define DEFINE_OUTER(LTYPE, RTYPE) \
|
#define DEFINE_OUTER(LTYPE, RTYPE) \
|
||||||
math_expression outer(LTYPE const & x, RTYPE const & y)\
|
math_expression outer(LTYPE const & x, RTYPE const & y)\
|
||||||
@@ -585,7 +585,7 @@ math_expression outer(LTYPE const & x, RTYPE const & y)\
|
|||||||
assert(x.dim()<=1 && y.dim()<=1);\
|
assert(x.dim()<=1 && y.dim()<=1);\
|
||||||
if(x.dim()<1 || y.dim()<1)\
|
if(x.dim()<1 || y.dim()<1)\
|
||||||
return x*y;\
|
return x*y;\
|
||||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
||||||
}\
|
}\
|
||||||
|
|
||||||
DEFINE_OUTER(array_base, array_base)
|
DEFINE_OUTER(array_base, array_base)
|
||||||
@@ -622,61 +622,61 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
|||||||
//---------------------------------------
|
//---------------------------------------
|
||||||
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
|
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
|
||||||
math_expression OPNAME (array_base const & x) \
|
math_expression OPNAME (array_base const & x) \
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||||
\
|
\
|
||||||
math_expression OPNAME (math_expression const & x) \
|
math_expression OPNAME (math_expression const & x) \
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||||
|
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?OPERATOR_FABS_TYPE:OPERATOR_ABS_TYPE, abs)
|
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
|
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ASIN_TYPE, asin)
|
DEFINE_ELEMENT_UNARY_OPERATOR(ASIN_TYPE, asin)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ATAN_TYPE, atan)
|
DEFINE_ELEMENT_UNARY_OPERATOR(ATAN_TYPE, atan)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_CEIL_TYPE, ceil)
|
DEFINE_ELEMENT_UNARY_OPERATOR(CEIL_TYPE, ceil)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COS_TYPE, cos)
|
DEFINE_ELEMENT_UNARY_OPERATOR(COS_TYPE, cos)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COSH_TYPE, cosh)
|
DEFINE_ELEMENT_UNARY_OPERATOR(COSH_TYPE, cosh)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_EXP_TYPE, exp)
|
DEFINE_ELEMENT_UNARY_OPERATOR(EXP_TYPE, exp)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_FLOOR_TYPE, floor)
|
DEFINE_ELEMENT_UNARY_OPERATOR(FLOOR_TYPE, floor)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG_TYPE, log)
|
DEFINE_ELEMENT_UNARY_OPERATOR(LOG_TYPE, log)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG10_TYPE,log10)
|
DEFINE_ELEMENT_UNARY_OPERATOR(LOG10_TYPE,log10)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SIN_TYPE, sin)
|
DEFINE_ELEMENT_UNARY_OPERATOR(SIN_TYPE, sin)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SINH_TYPE, sinh)
|
DEFINE_ELEMENT_UNARY_OPERATOR(SINH_TYPE, sinh)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SQRT_TYPE, sqrt)
|
DEFINE_ELEMENT_UNARY_OPERATOR(SQRT_TYPE, sqrt)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TAN_TYPE, tan)
|
DEFINE_ELEMENT_UNARY_OPERATOR(TAN_TYPE, tan)
|
||||||
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
|
DEFINE_ELEMENT_UNARY_OPERATOR(TANH_TYPE, tanh)
|
||||||
#undef DEFINE_ELEMENT_UNARY_OPERATOR
|
#undef DEFINE_ELEMENT_UNARY_OPERATOR
|
||||||
//---------------------------------------
|
//---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
///*--- Misc----*/
|
///*--- Misc----*/
|
||||||
////---------------------------------------
|
////---------------------------------------
|
||||||
inline operation_node_type casted(numeric_type dtype)
|
inline operation_type casted(numeric_type dtype)
|
||||||
{
|
{
|
||||||
switch(dtype)
|
switch(dtype)
|
||||||
{
|
{
|
||||||
// case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE;
|
// case BOOL_TYPE: return CAST_BOOL_TYPE;
|
||||||
case CHAR_TYPE: return OPERATOR_CAST_CHAR_TYPE;
|
case CHAR_TYPE: return CAST_CHAR_TYPE;
|
||||||
case UCHAR_TYPE: return OPERATOR_CAST_UCHAR_TYPE;
|
case UCHAR_TYPE: return CAST_UCHAR_TYPE;
|
||||||
case SHORT_TYPE: return OPERATOR_CAST_SHORT_TYPE;
|
case SHORT_TYPE: return CAST_SHORT_TYPE;
|
||||||
case USHORT_TYPE: return OPERATOR_CAST_USHORT_TYPE;
|
case USHORT_TYPE: return CAST_USHORT_TYPE;
|
||||||
case INT_TYPE: return OPERATOR_CAST_INT_TYPE;
|
case INT_TYPE: return CAST_INT_TYPE;
|
||||||
case UINT_TYPE: return OPERATOR_CAST_UINT_TYPE;
|
case UINT_TYPE: return CAST_UINT_TYPE;
|
||||||
case LONG_TYPE: return OPERATOR_CAST_LONG_TYPE;
|
case LONG_TYPE: return CAST_LONG_TYPE;
|
||||||
case ULONG_TYPE: return OPERATOR_CAST_ULONG_TYPE;
|
case ULONG_TYPE: return CAST_ULONG_TYPE;
|
||||||
// case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE;
|
// case FLOAT_TYPE: return CAST_HALF_TYPE;
|
||||||
case FLOAT_TYPE: return OPERATOR_CAST_FLOAT_TYPE;
|
case FLOAT_TYPE: return CAST_FLOAT_TYPE;
|
||||||
case DOUBLE_TYPE: return OPERATOR_CAST_DOUBLE_TYPE;
|
case DOUBLE_TYPE: return CAST_DOUBLE_TYPE;
|
||||||
default: throw unknown_datatype(dtype);
|
default: throw unknown_datatype(dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
math_expression cast(array_base const & x, numeric_type dtype)
|
math_expression cast(array_base const & x, numeric_type dtype)
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||||
|
|
||||||
math_expression cast(math_expression const & x, numeric_type dtype)
|
math_expression cast(math_expression const & x, numeric_type dtype)
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||||
|
|
||||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, {M, N}); }
|
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
|
||||||
|
|
||||||
array diag(array_base & x, int offset)
|
array diag(array_base & x, int offset)
|
||||||
{
|
{
|
||||||
@@ -689,7 +689,7 @@ array diag(array_base & x, int offset)
|
|||||||
|
|
||||||
|
|
||||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, {M, N}); }
|
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
|
||||||
|
|
||||||
inline shape_t flip(shape_t const & shape)
|
inline shape_t flip(shape_t const & shape)
|
||||||
{
|
{
|
||||||
@@ -703,28 +703,28 @@ inline shape_t flip(shape_t const & shape)
|
|||||||
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
|
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
|
||||||
|
|
||||||
math_expression trans(array_base const & x) \
|
math_expression trans(array_base const & x) \
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||||
\
|
\
|
||||||
math_expression trans(math_expression const & x) \
|
math_expression trans(math_expression const & x) \
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||||
|
|
||||||
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
||||||
{
|
{
|
||||||
int_t sub1 = A.shape()[0];
|
int_t sub1 = A.shape()[0];
|
||||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||||
}
|
}
|
||||||
|
|
||||||
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
|
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
|
||||||
{
|
{
|
||||||
int_t sub1 = A.shape()[0];
|
int_t sub1 = A.shape()[0];
|
||||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
|
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
|
||||||
math_expression row(TYPEA const & x, TYPEB const & i)\
|
math_expression row(TYPEA const & x, TYPEB const & i)\
|
||||||
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
||||||
|
|
||||||
DEFINE_ACCESS_ROW(array_base, value_scalar)
|
DEFINE_ACCESS_ROW(array_base, value_scalar)
|
||||||
DEFINE_ACCESS_ROW(array_base, for_idx_t)
|
DEFINE_ACCESS_ROW(array_base, for_idx_t)
|
||||||
@@ -736,7 +736,7 @@ DEFINE_ACCESS_ROW(math_expression, math_expression)
|
|||||||
|
|
||||||
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
|
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
|
||||||
math_expression col(TYPEA const & x, TYPEB const & i)\
|
math_expression col(TYPEA const & x, TYPEB const & i)\
|
||||||
{ return math_expression(x, i, op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
||||||
|
|
||||||
DEFINE_ACCESS_COL(array_base, value_scalar)
|
DEFINE_ACCESS_COL(array_base, value_scalar)
|
||||||
DEFINE_ACCESS_COL(array_base, for_idx_t)
|
DEFINE_ACCESS_COL(array_base, for_idx_t)
|
||||||
@@ -756,11 +756,11 @@ math_expression OPNAME(array_base const & x, int_t axis)\
|
|||||||
if(axis < -1 || axis > x.dim())\
|
if(axis < -1 || axis > x.dim())\
|
||||||
throw std::out_of_range("The axis entry is out of bounds");\
|
throw std::out_of_range("The axis entry is out of bounds");\
|
||||||
else if(axis==-1)\
|
else if(axis==-1)\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||||
else if(axis==0)\
|
else if(axis==0)\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||||
else\
|
else\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||||
}\
|
}\
|
||||||
\
|
\
|
||||||
math_expression OPNAME(math_expression const & x, int_t axis)\
|
math_expression OPNAME(math_expression const & x, int_t axis)\
|
||||||
@@ -768,18 +768,18 @@ math_expression OPNAME(math_expression const & x, int_t axis)\
|
|||||||
if(axis < -1 || axis > x.dim())\
|
if(axis < -1 || axis > x.dim())\
|
||||||
throw std::out_of_range("The axis entry is out of bounds");\
|
throw std::out_of_range("The axis entry is out of bounds");\
|
||||||
if(axis==-1)\
|
if(axis==-1)\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||||
else if(axis==0)\
|
else if(axis==0)\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||||
else\
|
else\
|
||||||
return math_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
|
DEFINE_REDUCTION(ADD_TYPE, sum)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
|
DEFINE_REDUCTION(ELEMENT_ARGMAX_TYPE, argmax)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max)
|
DEFINE_REDUCTION(ELEMENT_MAX_TYPE, max)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min)
|
DEFINE_REDUCTION(ELEMENT_MIN_TYPE, min)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
|
DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
|
||||||
|
|
||||||
#undef DEFINE_REDUCTION
|
#undef DEFINE_REDUCTION
|
||||||
|
|
||||||
@@ -789,21 +789,21 @@ namespace detail
|
|||||||
math_expression matmatprod(array_base const & A, array_base const & B)
|
math_expression matmatprod(array_base const & A, array_base const & B)
|
||||||
{
|
{
|
||||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||||
return math_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
|
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
math_expression matmatprod(math_expression const & A, array_base const & B)
|
math_expression matmatprod(math_expression const & A, array_base const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||||
|
|
||||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||||
if(A_trans){
|
if(A_trans){
|
||||||
type = OPERATOR_GEMM_TN_TYPE;
|
type = MATRIX_PRODUCT_TN_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||||
if(A_trans) res_root.lhs = A_root.lhs;
|
if(A_trans) res_root.lhs = A_root.lhs;
|
||||||
return res;
|
return res;
|
||||||
@@ -811,16 +811,16 @@ namespace detail
|
|||||||
|
|
||||||
math_expression matmatprod(array_base const & A, math_expression const & B)
|
math_expression matmatprod(array_base const & A, math_expression const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||||
|
|
||||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
bool B_trans = B_root.op.type==TRANS_TYPE;
|
||||||
if(B_trans){
|
if(B_trans){
|
||||||
type = OPERATOR_GEMM_NT_TYPE;
|
type = MATRIX_PRODUCT_NT_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||||
if(B_trans) res_root.rhs = B_root.lhs;
|
if(B_trans) res_root.rhs = B_root.lhs;
|
||||||
return res;
|
return res;
|
||||||
@@ -828,20 +828,20 @@ namespace detail
|
|||||||
|
|
||||||
math_expression matmatprod(math_expression const & A, math_expression const & B)
|
math_expression matmatprod(math_expression const & A, math_expression const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||||
|
|
||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
bool B_trans = B_root.op.type==TRANS_TYPE;
|
||||||
|
|
||||||
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
|
if(A_trans && B_trans) type = MATRIX_PRODUCT_TT_TYPE;
|
||||||
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
|
else if(A_trans && !B_trans) type = MATRIX_PRODUCT_TN_TYPE;
|
||||||
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
|
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
|
||||||
else type = OPERATOR_GEMM_NN_TYPE;
|
else type = MATRIX_PRODUCT_NN_TYPE;
|
||||||
|
|
||||||
math_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||||
if(A_trans) res_root.lhs = A_root.lhs;
|
if(A_trans) res_root.lhs = A_root.lhs;
|
||||||
if(B_trans) res_root.rhs = B_root.lhs;
|
if(B_trans) res_root.rhs = B_root.lhs;
|
||||||
@@ -862,14 +862,14 @@ namespace detail
|
|||||||
int_t M = A.shape()[0];
|
int_t M = A.shape()[0];
|
||||||
int_t N = A.shape()[1];
|
int_t N = A.shape()[1];
|
||||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||||
while(A_root.lhs.type_family==COMPOSITE_OPERATOR_FAMILY){
|
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
|
||||||
A_root = A.tree()[A_root.lhs.node_index];
|
A_root = A.tree()[A_root.lhs.node_index];
|
||||||
A_trans ^= A_root.op.type==OPERATOR_TRANS_TYPE;
|
A_trans ^= A_root.op.type==TRANS_TYPE;
|
||||||
}
|
}
|
||||||
if(A_trans)
|
if(A_trans)
|
||||||
{
|
{
|
||||||
math_expression tmp(A, repmat(x, 1, M), op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
||||||
//Remove trans
|
//Remove trans
|
||||||
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
||||||
return sum(tmp, 0);
|
return sum(tmp, 0);
|
||||||
@@ -890,10 +890,10 @@ ISAACAPI void swap(view x, view y)
|
|||||||
|
|
||||||
//Reshape
|
//Reshape
|
||||||
math_expression reshape(array_base const & x, shape_t const & shape)
|
math_expression reshape(array_base const & x, shape_t const & shape)
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||||
|
|
||||||
math_expression reshape(math_expression const & x, shape_t const & shape)
|
math_expression reshape(math_expression const & x, shape_t const & shape)
|
||||||
{ return math_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||||
|
|
||||||
math_expression ravel(array_base const & x)
|
math_expression ravel(array_base const & x)
|
||||||
{ return reshape(x, {x.shape().prod()}); }
|
{ return reshape(x, {x.shape().prod()}); }
|
||||||
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
|||||||
if(x.dim()==2 && x.shape()[1]==0)\
|
if(x.dim()==2 && x.shape()[1]==0)\
|
||||||
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
|
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
|
||||||
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
|
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
|
||||||
return math_expression(invalid_node(), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_INVALID_TYPE), context, dtype, {0});\
|
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
|
||||||
if(x.dim()==1 && y.dim()==1)\
|
if(x.dim()==1 && y.dim()==1)\
|
||||||
return sum(x*y);\
|
return sum(x*y);\
|
||||||
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
|
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
|
||||||
@@ -967,13 +967,13 @@ DEFINE_NORM(math_expression)
|
|||||||
math_expression fuse(math_expression const & x, math_expression const & y)
|
math_expression fuse(math_expression const & x, math_expression const & y)
|
||||||
{
|
{
|
||||||
assert(x.context()==y.context());
|
assert(x.context()==y.context());
|
||||||
return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*--- For loops ---*/
|
/*--- For loops ---*/
|
||||||
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
|
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
|
||||||
{
|
{
|
||||||
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -96,7 +96,7 @@ mapped_object& get(math_expression::container_type const & tree, size_t root, ma
|
|||||||
{
|
{
|
||||||
for(unsigned int i = 0 ; i < idx ; ++i){
|
for(unsigned int i = 0 ; i < idx ; ++i){
|
||||||
math_expression::node node = tree[root];
|
math_expression::node node = tree[root];
|
||||||
if(node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
root = node.rhs.node_index;
|
root = node.rhs.node_index;
|
||||||
else
|
else
|
||||||
return *(mapping.at(std::make_pair(root, RHS_NODE_TYPE)));
|
return *(mapping.at(std::make_pair(root, RHS_NODE_TYPE)));
|
||||||
@@ -136,10 +136,10 @@ math_expression::node mapped_reduce::root_node() const
|
|||||||
bool mapped_reduce::is_index_reduction() const
|
bool mapped_reduce::is_index_reduction() const
|
||||||
{
|
{
|
||||||
op_element const & op = root_op();
|
op_element const & op = root_op();
|
||||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
return op.type==ELEMENT_ARGFMAX_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
|| op.type==ELEMENT_ARGMAX_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE
|
|| op.type==ELEMENT_ARGFMIN_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
|
|| op.type==ELEMENT_ARGMIN_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
op_element mapped_reduce::root_op() const
|
op_element mapped_reduce::root_op() const
|
||||||
@@ -358,27 +358,27 @@ void mapped_outer::postprocess(std::string &res) const
|
|||||||
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
|
mapped_outer::mapped_outer(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "outer"), binary_leaf(info)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
std::string mapped_cast::operator_to_str(operation_node_type type)
|
std::string mapped_cast::operator_to_str(operation_type type)
|
||||||
{
|
{
|
||||||
switch(type)
|
switch(type)
|
||||||
{
|
{
|
||||||
case OPERATOR_CAST_BOOL_TYPE : return "bool";
|
case CAST_BOOL_TYPE : return "bool";
|
||||||
case OPERATOR_CAST_CHAR_TYPE : return "char";
|
case CAST_CHAR_TYPE : return "char";
|
||||||
case OPERATOR_CAST_UCHAR_TYPE : return "uchar";
|
case CAST_UCHAR_TYPE : return "uchar";
|
||||||
case OPERATOR_CAST_SHORT_TYPE : return "short";
|
case CAST_SHORT_TYPE : return "short";
|
||||||
case OPERATOR_CAST_USHORT_TYPE : return "ushort";
|
case CAST_USHORT_TYPE : return "ushort";
|
||||||
case OPERATOR_CAST_INT_TYPE : return "int";
|
case CAST_INT_TYPE : return "int";
|
||||||
case OPERATOR_CAST_UINT_TYPE : return "uint";
|
case CAST_UINT_TYPE : return "uint";
|
||||||
case OPERATOR_CAST_LONG_TYPE : return "long";
|
case CAST_LONG_TYPE : return "long";
|
||||||
case OPERATOR_CAST_ULONG_TYPE : return "ulong";
|
case CAST_ULONG_TYPE : return "ulong";
|
||||||
case OPERATOR_CAST_HALF_TYPE : return "half";
|
case CAST_HALF_TYPE : return "half";
|
||||||
case OPERATOR_CAST_FLOAT_TYPE : return "float";
|
case CAST_FLOAT_TYPE : return "float";
|
||||||
case OPERATOR_CAST_DOUBLE_TYPE : return "double";
|
case CAST_DOUBLE_TYPE : return "double";
|
||||||
default : return "invalid";
|
default : return "invalid";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_cast::mapped_cast(operation_node_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
|
mapped_cast::mapped_cast(operation_type type, unsigned int id) : mapped_object(operator_to_str(type), id, "cast")
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
|
||||||
|
@@ -16,103 +16,103 @@ namespace detail
|
|||||||
|
|
||||||
bool is_scalar_reduce_1d(math_expression::node const & node)
|
bool is_scalar_reduce_1d(math_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
|
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_vector_reduce_1d(math_expression::node const & node)
|
bool is_vector_reduce_1d(math_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
return node.op.type_family==ROWS_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
|
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_assignment(op_element const & op)
|
bool is_assignment(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type== OPERATOR_ASSIGN_TYPE
|
return op.type== ASSIGN_TYPE
|
||||||
|| op.type== OPERATOR_INPLACE_ADD_TYPE
|
|| op.type== INPLACE_ADD_TYPE
|
||||||
|| op.type== OPERATOR_INPLACE_SUB_TYPE;
|
|| op.type== INPLACE_SUB_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_elementwise_operator(op_element const & op)
|
bool is_elementwise_operator(op_element const & op)
|
||||||
{
|
{
|
||||||
return is_assignment(op)
|
return is_assignment(op)
|
||||||
|| op.type== OPERATOR_ADD_TYPE
|
|| op.type== ADD_TYPE
|
||||||
|| op.type== OPERATOR_SUB_TYPE
|
|| op.type== SUB_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_PROD_TYPE
|
|| op.type== ELEMENT_PROD_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_DIV_TYPE
|
|| op.type== ELEMENT_DIV_TYPE
|
||||||
|| op.type== OPERATOR_MULT_TYPE
|
|| op.type== MULT_TYPE
|
||||||
|| op.type== OPERATOR_DIV_TYPE
|
|| op.type== DIV_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_EQ_TYPE
|
|| op.type== ELEMENT_EQ_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_NEQ_TYPE
|
|| op.type== ELEMENT_NEQ_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_GREATER_TYPE
|
|| op.type== ELEMENT_GREATER_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_LESS_TYPE
|
|| op.type== ELEMENT_LESS_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_GEQ_TYPE
|
|| op.type== ELEMENT_GEQ_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_LEQ_TYPE ;
|
|| op.type== ELEMENT_LEQ_TYPE ;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bypass(op_element const & op)
|
bool bypass(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type == OPERATOR_RESHAPE_TYPE
|
return op.type == RESHAPE_TYPE
|
||||||
||op.type == OPERATOR_TRANS_TYPE;
|
||op.type == TRANS_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_cast(op_element const & op)
|
bool is_cast(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type== OPERATOR_CAST_BOOL_TYPE
|
return op.type== CAST_BOOL_TYPE
|
||||||
|| op.type== OPERATOR_CAST_CHAR_TYPE
|
|| op.type== CAST_CHAR_TYPE
|
||||||
|| op.type== OPERATOR_CAST_UCHAR_TYPE
|
|| op.type== CAST_UCHAR_TYPE
|
||||||
|| op.type== OPERATOR_CAST_SHORT_TYPE
|
|| op.type== CAST_SHORT_TYPE
|
||||||
|| op.type== OPERATOR_CAST_USHORT_TYPE
|
|| op.type== CAST_USHORT_TYPE
|
||||||
|| op.type== OPERATOR_CAST_INT_TYPE
|
|| op.type== CAST_INT_TYPE
|
||||||
|| op.type== OPERATOR_CAST_UINT_TYPE
|
|| op.type== CAST_UINT_TYPE
|
||||||
|| op.type== OPERATOR_CAST_LONG_TYPE
|
|| op.type== CAST_LONG_TYPE
|
||||||
|| op.type== OPERATOR_CAST_ULONG_TYPE
|
|| op.type== CAST_ULONG_TYPE
|
||||||
|| op.type== OPERATOR_CAST_FLOAT_TYPE
|
|| op.type== CAST_FLOAT_TYPE
|
||||||
|| op.type== OPERATOR_CAST_DOUBLE_TYPE
|
|| op.type== CAST_DOUBLE_TYPE
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_node_leaf(op_element const & op)
|
bool is_node_leaf(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type==OPERATOR_MATRIX_DIAG_TYPE
|
return op.type==MATRIX_DIAG_TYPE
|
||||||
|| op.type==OPERATOR_VDIAG_TYPE
|
|| op.type==VDIAG_TYPE
|
||||||
|| op.type==OPERATOR_REPEAT_TYPE
|
|| op.type==REPEAT_TYPE
|
||||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
|| op.type==MATRIX_ROW_TYPE
|
||||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
|| op.type==MATRIX_COLUMN_TYPE
|
||||||
|| op.type==OPERATOR_ACCESS_INDEX_TYPE
|
|| op.type==ACCESS_INDEX_TYPE
|
||||||
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
|| op.type==OUTER_PROD_TYPE
|
||||||
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
|| op.type_family==VECTOR_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
|| op.type_family==ROWS_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
|| op.type_family==COLUMNS_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
|
|| op.type_family==MATRIX_PRODUCT_TYPE_FAMILY
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_elementwise_function(op_element const & op)
|
bool is_elementwise_function(op_element const & op)
|
||||||
{
|
{
|
||||||
return is_cast(op)
|
return is_cast(op)
|
||||||
|| op.type== OPERATOR_ABS_TYPE
|
|| op.type== ABS_TYPE
|
||||||
|| op.type== OPERATOR_ACOS_TYPE
|
|| op.type== ACOS_TYPE
|
||||||
|| op.type== OPERATOR_ASIN_TYPE
|
|| op.type== ASIN_TYPE
|
||||||
|| op.type== OPERATOR_ATAN_TYPE
|
|| op.type== ATAN_TYPE
|
||||||
|| op.type== OPERATOR_CEIL_TYPE
|
|| op.type== CEIL_TYPE
|
||||||
|| op.type== OPERATOR_COS_TYPE
|
|| op.type== COS_TYPE
|
||||||
|| op.type== OPERATOR_COSH_TYPE
|
|| op.type== COSH_TYPE
|
||||||
|| op.type== OPERATOR_EXP_TYPE
|
|| op.type== EXP_TYPE
|
||||||
|| op.type== OPERATOR_FABS_TYPE
|
|| op.type== FABS_TYPE
|
||||||
|| op.type== OPERATOR_FLOOR_TYPE
|
|| op.type== FLOOR_TYPE
|
||||||
|| op.type== OPERATOR_LOG_TYPE
|
|| op.type== LOG_TYPE
|
||||||
|| op.type== OPERATOR_LOG10_TYPE
|
|| op.type== LOG10_TYPE
|
||||||
|| op.type== OPERATOR_SIN_TYPE
|
|| op.type== SIN_TYPE
|
||||||
|| op.type== OPERATOR_SINH_TYPE
|
|| op.type== SINH_TYPE
|
||||||
|| op.type== OPERATOR_SQRT_TYPE
|
|| op.type== SQRT_TYPE
|
||||||
|| op.type== OPERATOR_TAN_TYPE
|
|| op.type== TAN_TYPE
|
||||||
|| op.type== OPERATOR_TANH_TYPE
|
|| op.type== TANH_TYPE
|
||||||
|
|
||||||
|| op.type== OPERATOR_ELEMENT_POW_TYPE
|
|| op.type== ELEMENT_POW_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_FMAX_TYPE
|
|| op.type== ELEMENT_FMAX_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_FMIN_TYPE
|
|| op.type== ELEMENT_FMIN_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_MAX_TYPE
|
|| op.type== ELEMENT_MAX_TYPE
|
||||||
|| op.type== OPERATOR_ELEMENT_MIN_TYPE;
|
|| op.type== ELEMENT_MIN_TYPE;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +139,7 @@ std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
filter_elements_fun::filter_elements_fun(math_expression_node_subtype subtype, std::vector<lhs_rhs_element> & out) :
|
filter_elements_fun::filter_elements_fun(node_type subtype, std::vector<lhs_rhs_element> & out) :
|
||||||
subtype_(subtype), out_(out)
|
subtype_(subtype), out_(out)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
@@ -153,84 +153,84 @@ void filter_elements_fun::operator()(isaac::math_expression const & math_express
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::vector<lhs_rhs_element> filter_elements(math_expression_node_subtype subtype, isaac::math_expression const & math_expression)
|
std::vector<lhs_rhs_element> filter_elements(node_type subtype, isaac::math_expression const & math_expression)
|
||||||
{
|
{
|
||||||
std::vector<lhs_rhs_element> res;
|
std::vector<lhs_rhs_element> res;
|
||||||
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
|
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @brief generate a string from an operation_node_type */
|
/** @brief generate a string from an operation_type */
|
||||||
const char * evaluate(operation_node_type type)
|
const char * evaluate(operation_type type)
|
||||||
{
|
{
|
||||||
// unary expression
|
// unary expression
|
||||||
switch (type)
|
switch (type)
|
||||||
{
|
{
|
||||||
//Function
|
//Function
|
||||||
case OPERATOR_ABS_TYPE : return "abs";
|
case ABS_TYPE : return "abs";
|
||||||
case OPERATOR_ACOS_TYPE : return "acos";
|
case ACOS_TYPE : return "acos";
|
||||||
case OPERATOR_ASIN_TYPE : return "asin";
|
case ASIN_TYPE : return "asin";
|
||||||
case OPERATOR_ATAN_TYPE : return "atan";
|
case ATAN_TYPE : return "atan";
|
||||||
case OPERATOR_CEIL_TYPE : return "ceil";
|
case CEIL_TYPE : return "ceil";
|
||||||
case OPERATOR_COS_TYPE : return "cos";
|
case COS_TYPE : return "cos";
|
||||||
case OPERATOR_COSH_TYPE : return "cosh";
|
case COSH_TYPE : return "cosh";
|
||||||
case OPERATOR_EXP_TYPE : return "exp";
|
case EXP_TYPE : return "exp";
|
||||||
case OPERATOR_FABS_TYPE : return "fabs";
|
case FABS_TYPE : return "fabs";
|
||||||
case OPERATOR_FLOOR_TYPE : return "floor";
|
case FLOOR_TYPE : return "floor";
|
||||||
case OPERATOR_LOG_TYPE : return "log";
|
case LOG_TYPE : return "log";
|
||||||
case OPERATOR_LOG10_TYPE : return "log10";
|
case LOG10_TYPE : return "log10";
|
||||||
case OPERATOR_SIN_TYPE : return "sin";
|
case SIN_TYPE : return "sin";
|
||||||
case OPERATOR_SINH_TYPE : return "sinh";
|
case SINH_TYPE : return "sinh";
|
||||||
case OPERATOR_SQRT_TYPE : return "sqrt";
|
case SQRT_TYPE : return "sqrt";
|
||||||
case OPERATOR_TAN_TYPE : return "tan";
|
case TAN_TYPE : return "tan";
|
||||||
case OPERATOR_TANH_TYPE : return "tanh";
|
case TANH_TYPE : return "tanh";
|
||||||
|
|
||||||
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return "argfmax";
|
case ELEMENT_ARGFMAX_TYPE : return "argfmax";
|
||||||
case OPERATOR_ELEMENT_ARGMAX_TYPE : return "argmax";
|
case ELEMENT_ARGMAX_TYPE : return "argmax";
|
||||||
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return "argfmin";
|
case ELEMENT_ARGFMIN_TYPE : return "argfmin";
|
||||||
case OPERATOR_ELEMENT_ARGMIN_TYPE : return "argmin";
|
case ELEMENT_ARGMIN_TYPE : return "argmin";
|
||||||
case OPERATOR_ELEMENT_POW_TYPE : return "pow";
|
case ELEMENT_POW_TYPE : return "pow";
|
||||||
|
|
||||||
//Arithmetic
|
//Arithmetic
|
||||||
case OPERATOR_MINUS_TYPE : return "-";
|
case MINUS_TYPE : return "-";
|
||||||
case OPERATOR_ASSIGN_TYPE : return "=";
|
case ASSIGN_TYPE : return "=";
|
||||||
case OPERATOR_INPLACE_ADD_TYPE : return "+=";
|
case INPLACE_ADD_TYPE : return "+=";
|
||||||
case OPERATOR_INPLACE_SUB_TYPE : return "-=";
|
case INPLACE_SUB_TYPE : return "-=";
|
||||||
case OPERATOR_ADD_TYPE : return "+";
|
case ADD_TYPE : return "+";
|
||||||
case OPERATOR_SUB_TYPE : return "-";
|
case SUB_TYPE : return "-";
|
||||||
case OPERATOR_MULT_TYPE : return "*";
|
case MULT_TYPE : return "*";
|
||||||
case OPERATOR_ELEMENT_PROD_TYPE : return "*";
|
case ELEMENT_PROD_TYPE : return "*";
|
||||||
case OPERATOR_DIV_TYPE : return "/";
|
case DIV_TYPE : return "/";
|
||||||
case OPERATOR_ELEMENT_DIV_TYPE : return "/";
|
case ELEMENT_DIV_TYPE : return "/";
|
||||||
|
|
||||||
//Relational
|
//Relational
|
||||||
case OPERATOR_NEGATE_TYPE: return "!";
|
case NEGATE_TYPE: return "!";
|
||||||
case OPERATOR_ELEMENT_EQ_TYPE : return "==";
|
case ELEMENT_EQ_TYPE : return "==";
|
||||||
case OPERATOR_ELEMENT_NEQ_TYPE : return "!=";
|
case ELEMENT_NEQ_TYPE : return "!=";
|
||||||
case OPERATOR_ELEMENT_GREATER_TYPE : return ">";
|
case ELEMENT_GREATER_TYPE : return ">";
|
||||||
case OPERATOR_ELEMENT_GEQ_TYPE : return ">=";
|
case ELEMENT_GEQ_TYPE : return ">=";
|
||||||
case OPERATOR_ELEMENT_LESS_TYPE : return "<";
|
case ELEMENT_LESS_TYPE : return "<";
|
||||||
case OPERATOR_ELEMENT_LEQ_TYPE : return "<=";
|
case ELEMENT_LEQ_TYPE : return "<=";
|
||||||
|
|
||||||
case OPERATOR_ELEMENT_FMAX_TYPE : return "fmax";
|
case ELEMENT_FMAX_TYPE : return "fmax";
|
||||||
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
|
case ELEMENT_FMIN_TYPE : return "fmin";
|
||||||
case OPERATOR_ELEMENT_MAX_TYPE : return "max";
|
case ELEMENT_MAX_TYPE : return "max";
|
||||||
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
case ELEMENT_MIN_TYPE : return "min";
|
||||||
|
|
||||||
//Binary
|
//Binary
|
||||||
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
|
case MATRIX_PRODUCT_NN_TYPE : return "prodNN";
|
||||||
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
|
case MATRIX_PRODUCT_TN_TYPE : return "prodTN";
|
||||||
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
|
case MATRIX_PRODUCT_NT_TYPE : return "prodNT";
|
||||||
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
|
case MATRIX_PRODUCT_TT_TYPE : return "prodTT";
|
||||||
case OPERATOR_VDIAG_TYPE : return "vdiag";
|
case VDIAG_TYPE : return "vdiag";
|
||||||
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
|
case MATRIX_DIAG_TYPE : return "mdiag";
|
||||||
case OPERATOR_MATRIX_ROW_TYPE : return "row";
|
case MATRIX_ROW_TYPE : return "row";
|
||||||
case OPERATOR_MATRIX_COLUMN_TYPE : return "col";
|
case MATRIX_COLUMN_TYPE : return "col";
|
||||||
case OPERATOR_PAIR_TYPE: return "pair";
|
case PAIR_TYPE: return "pair";
|
||||||
case OPERATOR_ACCESS_INDEX_TYPE: return "access";
|
case ACCESS_INDEX_TYPE: return "access";
|
||||||
|
|
||||||
//FOR
|
//FOR
|
||||||
case OPERATOR_SFOR_TYPE: return "sfor";
|
case SFOR_TYPE: return "sfor";
|
||||||
|
|
||||||
default : throw operation_not_supported_exception("Unsupported operator");
|
default : throw operation_not_supported_exception("Unsupported operator");
|
||||||
}
|
}
|
||||||
@@ -245,7 +245,7 @@ void evaluate_expression_traversal::call_before_expansion(isaac::math_expression
|
|||||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||||
if(detail::is_cast(root_node.op))
|
if(detail::is_cast(root_node.op))
|
||||||
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
|
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
|
||||||
else if (( (root_node.op.type_family==OPERATOR_UNARY_TYPE_FAMILY&&root_node.op.type!=OPERATOR_ADD_TYPE) || detail::is_elementwise_function(root_node.op))
|
else if (( (root_node.op.type_family==UNARY_TYPE_FAMILY&&root_node.op.type!=ADD_TYPE) || detail::is_elementwise_function(root_node.op))
|
||||||
&& !detail::is_node_leaf(root_node.op))
|
&& !detail::is_node_leaf(root_node.op))
|
||||||
str_+=evaluate(root_node.op.type);
|
str_+=evaluate(root_node.op.type);
|
||||||
if(root_node.op.type!=OPERATOR_FUSE)
|
if(root_node.op.type!=OPERATOR_FUSE)
|
||||||
@@ -268,7 +268,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
|
|||||||
{
|
{
|
||||||
if (detail::is_node_leaf(root_node.op))
|
if (detail::is_node_leaf(root_node.op))
|
||||||
str_ += mapping_.at(key)->evaluate(accessors_);
|
str_ += mapping_.at(key)->evaluate(accessors_);
|
||||||
else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
|
else if(root_node.op.type_family!=UNARY_TYPE_FAMILY)
|
||||||
{
|
{
|
||||||
if (detail::is_elementwise_operator(root_node.op))
|
if (detail::is_elementwise_operator(root_node.op))
|
||||||
str_ += evaluate(root_node.op.type);
|
str_ += evaluate(root_node.op.type);
|
||||||
@@ -280,7 +280,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
|
|||||||
{
|
{
|
||||||
if (leaf==LHS_NODE_TYPE)
|
if (leaf==LHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.lhs.type_family!=COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.lhs.subtype!=COMPOSITE_OPERATOR_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.lhs.subtype==FOR_LOOP_INDEX_TYPE)
|
if (root_node.lhs.subtype==FOR_LOOP_INDEX_TYPE)
|
||||||
str_ += "sforidx" + tools::to_string(root_node.lhs.for_idx.level);
|
str_ += "sforidx" + tools::to_string(root_node.lhs.for_idx.level);
|
||||||
@@ -291,7 +291,7 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
|
|||||||
|
|
||||||
if (leaf==RHS_NODE_TYPE)
|
if (leaf==RHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.rhs.type_family!=COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.rhs.subtype!=COMPOSITE_OPERATOR_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.rhs.subtype==FOR_LOOP_INDEX_TYPE)
|
if (root_node.rhs.subtype==FOR_LOOP_INDEX_TYPE)
|
||||||
str_ += "sforidx" + tools::to_string(root_node.rhs.for_idx.level);
|
str_ += "sforidx" + tools::to_string(root_node.rhs.for_idx.level);
|
||||||
@@ -312,14 +312,14 @@ std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & acc
|
|||||||
|
|
||||||
if (leaf==RHS_NODE_TYPE)
|
if (leaf==RHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
|
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
|
||||||
else
|
else
|
||||||
traversal_functor(math_expression, root_idx, leaf);
|
traversal_functor(math_expression, root_idx, leaf);
|
||||||
}
|
}
|
||||||
else if (leaf==LHS_NODE_TYPE)
|
else if (leaf==LHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
|
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
|
||||||
else
|
else
|
||||||
traversal_functor(math_expression, root_idx, leaf);
|
traversal_functor(math_expression, root_idx, leaf);
|
||||||
@@ -369,14 +369,14 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::strin
|
|||||||
|
|
||||||
if (leaf==RHS_NODE_TYPE)
|
if (leaf==RHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
|
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
|
||||||
else
|
else
|
||||||
traversal_functor(math_expression, root_idx, leaf);
|
traversal_functor(math_expression, root_idx, leaf);
|
||||||
}
|
}
|
||||||
else if (leaf==LHS_NODE_TYPE)
|
else if (leaf==LHS_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
|
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
|
||||||
else
|
else
|
||||||
traversal_functor(math_expression, root_idx, leaf);
|
traversal_functor(math_expression, root_idx, leaf);
|
||||||
@@ -440,9 +440,9 @@ void math_expression_representation_functor::append(char*& p, const char * str)
|
|||||||
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
|
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
|
||||||
{
|
{
|
||||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||||
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
append(root_node.lhs, detail::is_assignment(root_node.op));
|
append(root_node.lhs, detail::is_assignment(root_node.op));
|
||||||
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
append(root_node.rhs, false);
|
append(root_node.rhs, false);
|
||||||
else if (leaf_t==PARENT_NODE_TYPE)
|
else if (leaf_t==PARENT_NODE_TYPE)
|
||||||
append_id(ptr_,root_node.op.type);
|
append_id(ptr_,root_node.op.type);
|
||||||
|
@@ -39,11 +39,11 @@ bool base::requires_fallback(math_expression const & expression)
|
|||||||
|
|
||||||
int_t base::vector_size(math_expression::node const & node)
|
int_t base::vector_size(math_expression::node const & node)
|
||||||
{
|
{
|
||||||
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
|
if (node.op.type==MATRIX_DIAG_TYPE)
|
||||||
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
|
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
|
||||||
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE)
|
else if (node.op.type==MATRIX_ROW_TYPE)
|
||||||
return node.lhs.array->shape()[1];
|
return node.lhs.array->shape()[1];
|
||||||
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
else if (node.op.type==MATRIX_COLUMN_TYPE)
|
||||||
return node.lhs.array->shape()[0];
|
return node.lhs.array->shape()[0];
|
||||||
else
|
else
|
||||||
return node.lhs.array->shape().max();
|
return node.lhs.array->shape().max();
|
||||||
@@ -52,12 +52,12 @@ int_t base::vector_size(math_expression::node const & node)
|
|||||||
|
|
||||||
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
|
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
|
||||||
{
|
{
|
||||||
if (node.op.type==OPERATOR_VDIAG_TYPE)
|
if (node.op.type==VDIAG_TYPE)
|
||||||
{
|
{
|
||||||
int_t size = node.lhs.array->shape()[0];
|
int_t size = node.lhs.array->shape()[0];
|
||||||
return std::make_pair(size,size);
|
return std::make_pair(size,size);
|
||||||
}
|
}
|
||||||
else if(node.op.type==OPERATOR_REPEAT_TYPE)
|
else if(node.op.type==REPEAT_TYPE)
|
||||||
{
|
{
|
||||||
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
|
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
|
||||||
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);
|
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);
|
||||||
|
@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
|
|||||||
|
|
||||||
|
|
||||||
math_expression::container_type const & tree = expressions.tree();
|
math_expression::container_type const & tree = expressions.tree();
|
||||||
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==OPERATOR_SFOR_TYPE;}, expressions, expressions.root(), true);
|
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
|
||||||
|
|
||||||
for(unsigned int i = 0 ; i < sfors.size() ; ++i)
|
for(unsigned int i = 0 ; i < sfors.size() ; ++i)
|
||||||
{
|
{
|
||||||
|
@@ -81,11 +81,11 @@ public:
|
|||||||
|
|
||||||
void set_arguments(lhs_rhs_element const & lhs_rhs, bool is_assigned) const
|
void set_arguments(lhs_rhs_element const & lhs_rhs, bool is_assigned) const
|
||||||
{
|
{
|
||||||
switch(lhs_rhs.type_family)
|
switch(lhs_rhs.subtype)
|
||||||
{
|
{
|
||||||
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
|
case VALUE_SCALAR_TYPE: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
|
||||||
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array, is_assigned);
|
case DENSE_ARRAY_TYPE: return set_arguments(lhs_rhs.array, is_assigned);
|
||||||
case PLACEHOLDER_TYPE_FAMILY: return;
|
case FOR_LOOP_INDEX_TYPE: return;
|
||||||
default: throw std::runtime_error("Unrecognized type family");
|
default: throw std::runtime_error("Unrecognized type family");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,9 +93,9 @@ public:
|
|||||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
|
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
|
||||||
{
|
{
|
||||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||||
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
|
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
|
||||||
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
set_arguments(root_node.rhs, false);
|
set_arguments(root_node.rhs, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -44,11 +44,11 @@ class map_functor : public traversal_functor
|
|||||||
|
|
||||||
std::shared_ptr<mapped_object> create(lhs_rhs_element const & lhs_rhs, bool is_assigned = false) const
|
std::shared_ptr<mapped_object> create(lhs_rhs_element const & lhs_rhs, bool is_assigned = false) const
|
||||||
{
|
{
|
||||||
switch(lhs_rhs.type_family)
|
switch(lhs_rhs.subtype)
|
||||||
{
|
{
|
||||||
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
|
case VALUE_SCALAR_TYPE: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
|
||||||
case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array, is_assigned);
|
case DENSE_ARRAY_TYPE: return create(lhs_rhs.array, is_assigned);
|
||||||
case PLACEHOLDER_TYPE_FAMILY: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level));
|
case FOR_LOOP_INDEX_TYPE: return std::shared_ptr<mapped_object>(new mapped_placeholder(lhs_rhs.for_idx.level));
|
||||||
default: throw "";
|
default: throw "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,31 +65,31 @@ public:
|
|||||||
mapping_type::key_type key(root_idx, leaf_t);
|
mapping_type::key_type key(root_idx, leaf_t);
|
||||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||||
|
|
||||||
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
|
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
|
||||||
else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.type_family != COMPOSITE_OPERATOR_FAMILY)
|
else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, create(root_node.rhs)));
|
mapping_.insert(mapping_type::value_type(key, create(root_node.rhs)));
|
||||||
else if ( leaf_t== PARENT_NODE_TYPE)
|
else if ( leaf_t== PARENT_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (root_node.op.type==OPERATOR_VDIAG_TYPE)
|
if (root_node.op.type==VDIAG_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
|
else if (root_node.op.type==MATRIX_DIAG_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type==OPERATOR_MATRIX_ROW_TYPE)
|
else if (root_node.op.type==MATRIX_ROW_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
else if (root_node.op.type==MATRIX_COLUMN_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
|
||||||
else if(root_node.op.type==OPERATOR_ACCESS_INDEX_TYPE)
|
else if(root_node.op.type==ACCESS_INDEX_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (detail::is_scalar_reduce_1d(root_node))
|
else if (detail::is_scalar_reduce_1d(root_node))
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (detail::is_vector_reduce_1d(root_node))
|
else if (detail::is_vector_reduce_1d(root_node))
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
|
else if (root_node.op.type_family == MATRIX_PRODUCT_TYPE_FAMILY)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
else if (root_node.op.type == REPEAT_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
else if (root_node.op.type == OUTER_PROD_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
|
||||||
else if (detail::is_cast(root_node.op))
|
else if (detail::is_cast(root_node.op))
|
||||||
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
|
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
|
||||||
|
@@ -25,10 +25,10 @@ inline void compute_index_reduce_1d(kernel_generation_stream & os, std::string a
|
|||||||
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
||||||
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
||||||
os << acc_value << "=";
|
os << acc_value << "=";
|
||||||
if (op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE) os << "fmax";
|
if (op.type==ELEMENT_ARGFMAX_TYPE) os << "fmax";
|
||||||
if (op.type==OPERATOR_ELEMENT_ARGMAX_TYPE) os << "max";
|
if (op.type==ELEMENT_ARGMAX_TYPE) os << "max";
|
||||||
if (op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE) os << "fmin";
|
if (op.type==ELEMENT_ARGFMIN_TYPE) os << "fmin";
|
||||||
if (op.type==OPERATOR_ELEMENT_ARGMIN_TYPE) os << "min";
|
if (op.type==ELEMENT_ARGMIN_TYPE) os << "min";
|
||||||
os << "(" << acc_value << "," << cur_value << ");"<< std::endl;
|
os << "(" << acc_value << "," << cur_value << ");"<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,17 +39,17 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
|
|||||||
|
|
||||||
switch (op.type)
|
switch (op.type)
|
||||||
{
|
{
|
||||||
case OPERATOR_ADD_TYPE : return "0";
|
case ADD_TYPE : return "0";
|
||||||
case OPERATOR_MULT_TYPE : return "1";
|
case MULT_TYPE : return "1";
|
||||||
case OPERATOR_DIV_TYPE : return "1";
|
case DIV_TYPE : return "1";
|
||||||
case OPERATOR_ELEMENT_FMAX_TYPE : return N_INF;
|
case ELEMENT_FMAX_TYPE : return N_INF;
|
||||||
case OPERATOR_ELEMENT_ARGFMAX_TYPE : return N_INF;
|
case ELEMENT_ARGFMAX_TYPE : return N_INF;
|
||||||
case OPERATOR_ELEMENT_MAX_TYPE : return N_INF;
|
case ELEMENT_MAX_TYPE : return N_INF;
|
||||||
case OPERATOR_ELEMENT_ARGMAX_TYPE : return N_INF;
|
case ELEMENT_ARGMAX_TYPE : return N_INF;
|
||||||
case OPERATOR_ELEMENT_FMIN_TYPE : return INF;
|
case ELEMENT_FMIN_TYPE : return INF;
|
||||||
case OPERATOR_ELEMENT_ARGFMIN_TYPE : return INF;
|
case ELEMENT_ARGFMIN_TYPE : return INF;
|
||||||
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
|
case ELEMENT_MIN_TYPE : return INF;
|
||||||
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
|
case ELEMENT_ARGMIN_TYPE : return INF;
|
||||||
|
|
||||||
default: throw std::runtime_error("Unsupported reduce_1d operator : no neutral element known");
|
default: throw std::runtime_error("Unsupported reduce_1d operator : no neutral element known");
|
||||||
}
|
}
|
||||||
@@ -57,18 +57,18 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
|
|||||||
|
|
||||||
inline bool is_reduce_1d(math_expression::node const & node)
|
inline bool is_reduce_1d(math_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
|
|| node.op.type_family==ROWS_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
inline bool is_index_reduction(op_element const & op)
|
inline bool is_index_reduction(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
return op.type==ELEMENT_ARGFMAX_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
|| op.type==ELEMENT_ARGMAX_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGFMIN_TYPE
|
|| op.type==ELEMENT_ARGFMIN_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
|
|| op.type==ELEMENT_ARGMIN_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -32,27 +32,27 @@ namespace isaac
|
|||||||
bool result = false;
|
bool result = false;
|
||||||
switch(op.type_family)
|
switch(op.type_family)
|
||||||
{
|
{
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
case UNARY_TYPE_FAMILY:
|
||||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
case BINARY_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|
|| (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|
||||||
|| (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
|
|| (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
|
||||||
break;
|
break;
|
||||||
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
case VECTOR_DOT_TYPE_FAMILY:
|
||||||
result |= is_mvprod(expression)
|
result |= is_mvprod(expression)
|
||||||
|| expression==REDUCE_1D;
|
|| expression==REDUCE_1D;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
case ROWS_DOT_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCE_1D;
|
|| expression==REDUCE_1D;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
case COLUMNS_DOT_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCE_1D;
|
|| expression==REDUCE_1D;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_GEMM_TYPE_FAMILY:
|
case MATRIX_PRODUCT_TYPE_FAMILY:
|
||||||
result |= (is_mmprod(expression) && !is_first)
|
result |= (is_mmprod(expression) && !is_first)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCE_1D;
|
|| expression==REDUCE_1D;
|
||||||
@@ -74,30 +74,30 @@ namespace isaac
|
|||||||
{
|
{
|
||||||
switch(op.type_family)
|
switch(op.type_family)
|
||||||
{
|
{
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
case UNARY_TYPE_FAMILY:
|
||||||
if(is_mmprod(left))
|
if(is_mmprod(left))
|
||||||
return ELEMENTWISE_2D;
|
return ELEMENTWISE_2D;
|
||||||
return left;
|
return left;
|
||||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
case BINARY_TYPE_FAMILY:
|
||||||
if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
|
if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
|
||||||
else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
|
else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
|
||||||
else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
|
else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
|
||||||
else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
|
else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
|
||||||
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OPERATOR_OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
|
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
|
||||||
else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
|
else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
|
||||||
else if(right == INVALID_EXPRESSION_TYPE) return left;
|
else if(right == INVALID_EXPRESSION_TYPE) return left;
|
||||||
else if(left == INVALID_EXPRESSION_TYPE) return right;
|
else if(left == INVALID_EXPRESSION_TYPE) return right;
|
||||||
throw;
|
throw;
|
||||||
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
case VECTOR_DOT_TYPE_FAMILY:
|
||||||
return REDUCE_1D;
|
return REDUCE_1D;
|
||||||
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
case ROWS_DOT_TYPE_FAMILY:
|
||||||
return REDUCE_2D_ROWS;
|
return REDUCE_2D_ROWS;
|
||||||
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
case COLUMNS_DOT_TYPE_FAMILY:
|
||||||
return REDUCE_2D_COLS;
|
return REDUCE_2D_COLS;
|
||||||
case OPERATOR_GEMM_TYPE_FAMILY:
|
case MATRIX_PRODUCT_TYPE_FAMILY:
|
||||||
if(op.type==OPERATOR_GEMM_NN_TYPE) return MATRIX_PRODUCT_NN;
|
if(op.type==MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN;
|
||||||
else if(op.type==OPERATOR_GEMM_TN_TYPE) return MATRIX_PRODUCT_TN;
|
else if(op.type==MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN;
|
||||||
else if(op.type==OPERATOR_GEMM_NT_TYPE) return MATRIX_PRODUCT_NT;
|
else if(op.type==MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT;
|
||||||
else return MATRIX_PRODUCT_TT;
|
else return MATRIX_PRODUCT_TT;
|
||||||
default:
|
default:
|
||||||
throw;
|
throw;
|
||||||
@@ -115,11 +115,11 @@ namespace isaac
|
|||||||
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
|
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
|
||||||
//Left
|
//Left
|
||||||
expression_type type_left = INVALID_EXPRESSION_TYPE;
|
expression_type type_left = INVALID_EXPRESSION_TYPE;
|
||||||
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
||||||
parse(array, node.lhs.node_index, breakpoints, type_left, false);
|
parse(array, node.lhs.node_index, breakpoints, type_left, false);
|
||||||
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
|
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
|
||||||
type_left = ELEMENTWISE_1D;
|
type_left = ELEMENTWISE_1D;
|
||||||
else
|
else
|
||||||
type_left = ELEMENTWISE_2D;
|
type_left = ELEMENTWISE_2D;
|
||||||
@@ -127,11 +127,11 @@ namespace isaac
|
|||||||
|
|
||||||
//Right
|
//Right
|
||||||
expression_type type_right = INVALID_EXPRESSION_TYPE;
|
expression_type type_right = INVALID_EXPRESSION_TYPE;
|
||||||
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
||||||
parse(array, node.rhs.node_index, breakpoints, type_right, false);
|
parse(array, node.rhs.node_index, breakpoints, type_right, false);
|
||||||
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
|
if(node.op.type==MATRIX_ROW_TYPE || node.op.type==MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
|
||||||
type_right = ELEMENTWISE_1D;
|
type_right = ELEMENTWISE_1D;
|
||||||
else
|
else
|
||||||
type_right = ELEMENTWISE_2D;
|
type_right = ELEMENTWISE_2D;
|
||||||
@@ -160,7 +160,7 @@ namespace isaac
|
|||||||
std::vector<std::shared_ptr<array> > temporaries_;
|
std::vector<std::shared_ptr<array> > temporaries_;
|
||||||
|
|
||||||
expression_type final_type;
|
expression_type final_type;
|
||||||
//GEMM
|
//MATRIX_PRODUCT
|
||||||
if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){
|
if(symbolic::preset::matrix_product::args args = symbolic::preset::matrix_product::check(tree, rootidx)){
|
||||||
final_type = args.type;
|
final_type = args.type;
|
||||||
}
|
}
|
||||||
@@ -208,10 +208,10 @@ namespace isaac
|
|||||||
}
|
}
|
||||||
temporaries_.push_back(tmp);
|
temporaries_.push_back(tmp);
|
||||||
|
|
||||||
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
tree[rootidx].op.type = ASSIGN_TYPE;
|
||||||
fill(tree[rootidx].lhs, (array&)*tmp);
|
fill(tree[rootidx].lhs, (array&)*tmp);
|
||||||
tree[rootidx].rhs = *it->second;
|
tree[rootidx].rhs = *it->second;
|
||||||
tree[rootidx].rhs.type_family = it->second->type_family;
|
tree[rootidx].rhs.subtype = it->second->subtype;
|
||||||
|
|
||||||
//Execute
|
//Execute
|
||||||
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||||
|
@@ -10,22 +10,19 @@ namespace isaac
|
|||||||
|
|
||||||
void fill(lhs_rhs_element &x, invalid_node)
|
void fill(lhs_rhs_element &x, invalid_node)
|
||||||
{
|
{
|
||||||
x.type_family = INVALID_TYPE_FAMILY;
|
|
||||||
x.subtype = INVALID_SUBTYPE;
|
x.subtype = INVALID_SUBTYPE;
|
||||||
x.dtype = INVALID_NUMERIC_TYPE;
|
x.dtype = INVALID_NUMERIC_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
void fill(lhs_rhs_element & x, std::size_t node_index)
|
void fill(lhs_rhs_element & x, std::size_t node_index)
|
||||||
{
|
{
|
||||||
x.type_family = COMPOSITE_OPERATOR_FAMILY;
|
x.subtype = COMPOSITE_OPERATOR_TYPE;
|
||||||
x.subtype = INVALID_SUBTYPE;
|
|
||||||
x.dtype = INVALID_NUMERIC_TYPE;
|
x.dtype = INVALID_NUMERIC_TYPE;
|
||||||
x.node_index = node_index;
|
x.node_index = node_index;
|
||||||
}
|
}
|
||||||
|
|
||||||
void fill(lhs_rhs_element & x, for_idx_t index)
|
void fill(lhs_rhs_element & x, for_idx_t index)
|
||||||
{
|
{
|
||||||
x.type_family = PLACEHOLDER_TYPE_FAMILY;
|
|
||||||
x.subtype = FOR_LOOP_INDEX_TYPE;
|
x.subtype = FOR_LOOP_INDEX_TYPE;
|
||||||
x.dtype = INVALID_NUMERIC_TYPE;
|
x.dtype = INVALID_NUMERIC_TYPE;
|
||||||
x.for_idx = index;
|
x.for_idx = index;
|
||||||
@@ -33,7 +30,6 @@ void fill(lhs_rhs_element & x, for_idx_t index)
|
|||||||
|
|
||||||
void fill(lhs_rhs_element & x, array_base const & a)
|
void fill(lhs_rhs_element & x, array_base const & a)
|
||||||
{
|
{
|
||||||
x.type_family = ARRAY_TYPE_FAMILY;
|
|
||||||
x.subtype = DENSE_ARRAY_TYPE;
|
x.subtype = DENSE_ARRAY_TYPE;
|
||||||
x.dtype = a.dtype();
|
x.dtype = a.dtype();
|
||||||
x.array = (array_base*)&a;
|
x.array = (array_base*)&a;
|
||||||
@@ -41,7 +37,6 @@ void fill(lhs_rhs_element & x, array_base const & a)
|
|||||||
|
|
||||||
void fill(lhs_rhs_element & x, value_scalar const & v)
|
void fill(lhs_rhs_element & x, value_scalar const & v)
|
||||||
{
|
{
|
||||||
x.type_family = VALUE_TYPE_FAMILY;
|
|
||||||
x.dtype = v.dtype();
|
x.dtype = v.dtype();
|
||||||
x.subtype = VALUE_SCALAR_TYPE;
|
x.subtype = VALUE_SCALAR_TYPE;
|
||||||
x.vscalar = v.values();
|
x.vscalar = v.values();
|
||||||
@@ -51,7 +46,7 @@ lhs_rhs_element::lhs_rhs_element(){}
|
|||||||
|
|
||||||
//
|
//
|
||||||
op_element::op_element() {}
|
op_element::op_element() {}
|
||||||
op_element::op_element(operation_node_type_family const & _type_family, operation_node_type const & _type) : type_family(_type_family), type(_type){}
|
op_element::op_element(operation_type_family const & _type_family, operation_type const & _type) : type_family(_type_family), type(_type){}
|
||||||
|
|
||||||
//
|
//
|
||||||
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
|
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
|
||||||
@@ -122,8 +117,8 @@ math_expression::math_expression(math_expression const & lhs, math_expression co
|
|||||||
tree_[root_].op = op;
|
tree_[root_].op = op;
|
||||||
fill(tree_[root_].rhs, lsize + rhs.root_);
|
fill(tree_[root_].rhs, lsize + rhs.root_);
|
||||||
for(container_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it){
|
for(container_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it){
|
||||||
if(it->lhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->lhs.node_index+=lsize;
|
if(it->lhs.subtype==COMPOSITE_OPERATOR_TYPE) it->lhs.node_index+=lsize;
|
||||||
if(it->rhs.type_family==COMPOSITE_OPERATOR_FAMILY) it->rhs.node_index+=lsize;
|
if(it->rhs.subtype==COMPOSITE_OPERATOR_TYPE) it->rhs.node_index+=lsize;
|
||||||
}
|
}
|
||||||
root_ = tree_.size() - 1;
|
root_ = tree_.size() - 1;
|
||||||
}
|
}
|
||||||
@@ -181,17 +176,17 @@ int_t math_expression::dim() const
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
math_expression math_expression::operator-()
|
math_expression math_expression::operator-()
|
||||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), *context_, dtype_, shape_); }
|
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
|
||||||
|
|
||||||
math_expression math_expression::operator!()
|
math_expression math_expression::operator!()
|
||||||
{ return math_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), *context_, INT_TYPE, shape_); }
|
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
||||||
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
|
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
|
||||||
{
|
{
|
||||||
math_expression::node const * current = &init;
|
math_expression::node const * current = &init;
|
||||||
while (current->lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
while (current->lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
current = &array[current->lhs.node_index];
|
current = &array[current->lhs.node_index];
|
||||||
return *current;
|
return *current;
|
||||||
}
|
}
|
||||||
@@ -200,8 +195,8 @@ math_expression::node const & lhs_most(math_expression::container_type const & a
|
|||||||
{ return lhs_most(array, array[root]); }
|
{ return lhs_most(array, array[root]); }
|
||||||
|
|
||||||
//
|
//
|
||||||
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.dtype()); }
|
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
|
||||||
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(OPERATOR_BINARY_TYPE_FAMILY,OPERATOR_ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
|
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
|
||||||
|
|
||||||
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
|
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
|
||||||
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
|
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
|
||||||
|
@@ -10,7 +10,7 @@ namespace isaac
|
|||||||
|
|
||||||
#define ISAAC_MAP_TO_STRING(NAME) case NAME: return #NAME
|
#define ISAAC_MAP_TO_STRING(NAME) case NAME: return #NAME
|
||||||
|
|
||||||
inline std::string to_string(math_expression_node_subtype const & f)
|
inline std::string to_string(node_type const & f)
|
||||||
{
|
{
|
||||||
switch(f)
|
switch(f)
|
||||||
{
|
{
|
||||||
@@ -23,7 +23,7 @@ inline std::string to_string(math_expression_node_subtype const & f)
|
|||||||
|
|
||||||
inline std::string to_string(lhs_rhs_element const & e)
|
inline std::string to_string(lhs_rhs_element const & e)
|
||||||
{
|
{
|
||||||
if(e.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(e.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
{
|
{
|
||||||
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
|
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
|
||||||
}
|
}
|
||||||
@@ -53,10 +53,10 @@ namespace detail
|
|||||||
|
|
||||||
os << "Node " << node_index << ": " << current_node << std::endl;
|
os << "Node " << node_index << ": " << current_node << std::endl;
|
||||||
|
|
||||||
if (current_node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (current_node.lhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
||||||
print_node(os, s, current_node.lhs.node_index, indent+1);
|
print_node(os, s, current_node.lhs.node_index, indent+1);
|
||||||
|
|
||||||
if (current_node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (current_node.rhs.subtype == COMPOSITE_OPERATOR_TYPE)
|
||||||
print_node(os, s, current_node.rhs.node_index, indent+1);
|
print_node(os, s, current_node.rhs.node_index, indent+1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -12,33 +12,33 @@ namespace preset
|
|||||||
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
|
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
|
||||||
{
|
{
|
||||||
//Matrix-Matrix product node
|
//Matrix-Matrix product node
|
||||||
if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
|
if(tree[rootidx].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
|
||||||
{
|
{
|
||||||
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
|
if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE) a.A = &tree[rootidx].lhs;
|
||||||
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
|
if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE) a.B = &tree[rootidx].rhs;
|
||||||
switch(tree[rootidx].op.type)
|
switch(tree[rootidx].op.type)
|
||||||
{
|
{
|
||||||
case OPERATOR_GEMM_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
|
case MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
|
||||||
case OPERATOR_GEMM_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
|
case MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
|
||||||
case OPERATOR_GEMM_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
|
case MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
|
||||||
case OPERATOR_GEMM_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
|
case MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Scalar multiplication node
|
//Scalar multiplication node
|
||||||
if(tree[rootidx].op.type==OPERATOR_MULT_TYPE)
|
if(tree[rootidx].op.type==MULT_TYPE)
|
||||||
{
|
{
|
||||||
//alpha*PROD
|
//alpha*PROD
|
||||||
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
|
if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE
|
||||||
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
|
&& tree[tree[rootidx].rhs.node_index].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
|
||||||
{
|
{
|
||||||
a.alpha = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
|
a.alpha = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
|
||||||
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
//beta*C
|
//beta*C
|
||||||
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
|
if(tree[rootidx].lhs.subtype==VALUE_SCALAR_TYPE && tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
a.beta = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
|
a.beta = value_scalar(tree[rootidx].lhs.vscalar, tree[rootidx].lhs.dtype);
|
||||||
a.C = &tree[rootidx].rhs;
|
a.C = &tree[rootidx].rhs;
|
||||||
@@ -55,26 +55,26 @@ matrix_product::args matrix_product::check(math_expression::container_type const
|
|||||||
return result;
|
return result;
|
||||||
result.alpha = value_scalar(1, dtype);
|
result.alpha = value_scalar(1, dtype);
|
||||||
result.beta = value_scalar(0, dtype);
|
result.beta = value_scalar(0, dtype);
|
||||||
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
{
|
{
|
||||||
rootidx = tree[rootidx].rhs.node_index;
|
rootidx = tree[rootidx].rhs.node_index;
|
||||||
bool is_add = tree[rootidx].op.type==OPERATOR_ADD_TYPE;
|
bool is_add = tree[rootidx].op.type==ADD_TYPE;
|
||||||
bool is_sub = tree[rootidx].op.type==OPERATOR_SUB_TYPE;
|
bool is_sub = tree[rootidx].op.type==SUB_TYPE;
|
||||||
//Form X +- Y"
|
//Form X +- Y"
|
||||||
if(is_add || is_sub)
|
if(is_add || is_sub)
|
||||||
{
|
{
|
||||||
if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(tree[rootidx].lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
handle_node(tree, tree[rootidx].lhs.node_index, result);
|
handle_node(tree, tree[rootidx].lhs.node_index, result);
|
||||||
else if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY)
|
else if(tree[rootidx].lhs.subtype==DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
result.C = &tree[rootidx].lhs;
|
result.C = &tree[rootidx].lhs;
|
||||||
result.beta = value_scalar(1, dtype);
|
result.beta = value_scalar(1, dtype);
|
||||||
result.alpha = value_scalar(is_add?1:-1, dtype);
|
result.alpha = value_scalar(is_add?1:-1, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
if(tree[rootidx].rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||||
handle_node(tree, tree[rootidx].rhs.node_index, result);
|
handle_node(tree, tree[rootidx].rhs.node_index, result);
|
||||||
else if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
|
else if(tree[rootidx].rhs.subtype==DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
result.C = &tree[rootidx].rhs;
|
result.C = &tree[rootidx].rhs;
|
||||||
result.alpha = value_scalar(1, dtype);
|
result.alpha = value_scalar(1, dtype);
|
||||||
|
@@ -169,7 +169,7 @@ extern "C"
|
|||||||
//BLAS3
|
//BLAS3
|
||||||
//*****************
|
//*****************
|
||||||
|
|
||||||
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
|
#define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
|
||||||
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
|
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
|
||||||
size_t M, size_t N, size_t K,\
|
size_t M, size_t N, size_t K,\
|
||||||
TYPE_CL alpha, const cl_mem cmA, size_t offA, size_t lda,\
|
TYPE_CL alpha, const cl_mem cmA, size_t offA, size_t lda,\
|
||||||
@@ -208,8 +208,8 @@ extern "C"
|
|||||||
return clblasSuccess;\
|
return clblasSuccess;\
|
||||||
}
|
}
|
||||||
|
|
||||||
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float)
|
MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
|
||||||
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double)
|
MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
|
||||||
|
|
||||||
#undef DOT
|
#undef DOT
|
||||||
|
|
||||||
|
@@ -194,7 +194,7 @@ extern "C"
|
|||||||
//BLAS3
|
//BLAS3
|
||||||
//*****************
|
//*****************
|
||||||
|
|
||||||
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \
|
#define MAKE_MATRIX_PRODUCT(TYPE_CHAR, TYPE_ISAAC, TYPE_CU) \
|
||||||
void cublas ## TYPE_CHAR ## gemm (char transa, char transb, int m, int n, int k,\
|
void cublas ## TYPE_CHAR ## gemm (char transa, char transb, int m, int n, int k,\
|
||||||
TYPE_CU alpha, const TYPE_CU *A, int lda,\
|
TYPE_CU alpha, const TYPE_CU *A, int lda,\
|
||||||
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
|
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
|
||||||
@@ -235,6 +235,6 @@ extern "C"
|
|||||||
return CUBLAS_STATUS_SUCCESS;\
|
return CUBLAS_STATUS_SUCCESS;\
|
||||||
}
|
}
|
||||||
|
|
||||||
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float)
|
MAKE_MATRIX_PRODUCT(S, sc::FLOAT_TYPE, cl_float)
|
||||||
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double)
|
MAKE_MATRIX_PRODUCT(D, sc::DOUBLE_TYPE, cl_double)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user