Code quality: renamed math_expression -> expression_tree
This commit is contained in:
@@ -59,30 +59,30 @@ public:
|
||||
|
||||
//Numeric operators
|
||||
array_base& operator=(array_base const &);
|
||||
array_base& operator=(math_expression const &);
|
||||
array_base& operator=(expression_tree const &);
|
||||
array_base& operator=(execution_handler const &);
|
||||
template<class T>
|
||||
array_base & operator=(std::vector<T> const & rhs);
|
||||
array_base & operator=(value_scalar const & rhs);
|
||||
|
||||
math_expression operator-();
|
||||
math_expression operator!();
|
||||
expression_tree operator-();
|
||||
expression_tree operator!();
|
||||
|
||||
array_base& operator+=(value_scalar const &);
|
||||
array_base& operator+=(array_base const &);
|
||||
array_base& operator+=(math_expression const &);
|
||||
array_base& operator+=(expression_tree const &);
|
||||
array_base& operator-=(value_scalar const &);
|
||||
array_base& operator-=(array_base const &);
|
||||
array_base& operator-=(math_expression const &);
|
||||
array_base& operator-=(expression_tree const &);
|
||||
array_base& operator*=(value_scalar const &);
|
||||
array_base& operator*=(array_base const &);
|
||||
array_base& operator*=(math_expression const &);
|
||||
array_base& operator*=(expression_tree const &);
|
||||
array_base& operator/=(value_scalar const &);
|
||||
array_base& operator/=(array_base const &);
|
||||
array_base& operator/=(math_expression const &);
|
||||
array_base& operator/=(expression_tree const &);
|
||||
|
||||
//Indexing (1D)
|
||||
math_expression operator[](for_idx_t idx) const;
|
||||
expression_tree operator[](for_idx_t idx) const;
|
||||
const scalar operator[](int_t) const;
|
||||
scalar operator[](int_t);
|
||||
view operator[](slice const &);
|
||||
@@ -105,7 +105,7 @@ protected:
|
||||
driver::Buffer data_;
|
||||
|
||||
public:
|
||||
math_expression T;
|
||||
expression_tree T;
|
||||
};
|
||||
|
||||
class ISAACAPI array : public array_base
|
||||
@@ -115,7 +115,7 @@ public:
|
||||
//Copy Constructor
|
||||
array(array_base const &);
|
||||
array(array const &);
|
||||
array(math_expression const & proxy);
|
||||
array(expression_tree const & proxy);
|
||||
using array_base::operator=;
|
||||
};
|
||||
|
||||
@@ -132,7 +132,7 @@ public:
|
||||
class ISAACAPI scalar : public array_base
|
||||
{
|
||||
friend value_scalar::value_scalar(const scalar &);
|
||||
friend value_scalar::value_scalar(const math_expression &);
|
||||
friend value_scalar::value_scalar(const expression_tree &);
|
||||
private:
|
||||
void inject(values_holder&) const;
|
||||
template<class T> T cast() const;
|
||||
@@ -140,7 +140,7 @@ public:
|
||||
explicit scalar(numeric_type dtype, const driver::Buffer &data, int_t offset);
|
||||
explicit scalar(value_scalar value, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
explicit scalar(numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
scalar(math_expression const & proxy);
|
||||
scalar(expression_tree const & proxy);
|
||||
scalar& operator=(value_scalar const &);
|
||||
// scalar& operator=(scalar const & s);
|
||||
using array_base::operator =;
|
||||
@@ -178,24 +178,24 @@ template<class T> ISAACAPI void copy(array_base const & gA, std::vector<T> & cA,
|
||||
//Binary operators
|
||||
|
||||
#define ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(OPNAME) \
|
||||
ISAACAPI math_expression OPNAME (array_base const & x, math_expression const & y);\
|
||||
ISAACAPI math_expression OPNAME (array_base const & x, value_scalar const & y);\
|
||||
ISAACAPI math_expression OPNAME (array_base const & x, for_idx_t const & y);\
|
||||
ISAACAPI math_expression OPNAME (array_base const & x, array_base const & y);\
|
||||
ISAACAPI expression_tree OPNAME (array_base const & x, expression_tree const & y);\
|
||||
ISAACAPI expression_tree OPNAME (array_base const & x, value_scalar const & y);\
|
||||
ISAACAPI expression_tree OPNAME (array_base const & x, for_idx_t const & y);\
|
||||
ISAACAPI expression_tree OPNAME (array_base const & x, array_base const & y);\
|
||||
\
|
||||
ISAACAPI math_expression OPNAME (math_expression const & x, math_expression const & y);\
|
||||
ISAACAPI math_expression OPNAME (math_expression const & x, value_scalar const & y);\
|
||||
ISAACAPI math_expression OPNAME (math_expression const & x, for_idx_t const & y);\
|
||||
ISAACAPI math_expression OPNAME (math_expression const & x, array_base const & y);\
|
||||
ISAACAPI expression_tree OPNAME (expression_tree const & x, expression_tree const & y);\
|
||||
ISAACAPI expression_tree OPNAME (expression_tree const & x, value_scalar const & y);\
|
||||
ISAACAPI expression_tree OPNAME (expression_tree const & x, for_idx_t const & y);\
|
||||
ISAACAPI expression_tree OPNAME (expression_tree const & x, array_base const & y);\
|
||||
\
|
||||
ISAACAPI math_expression OPNAME (value_scalar const & y, math_expression const & x);\
|
||||
ISAACAPI math_expression OPNAME (value_scalar const & y, for_idx_t const & x);\
|
||||
ISAACAPI math_expression OPNAME (value_scalar const & y, array_base const & x);\
|
||||
ISAACAPI expression_tree OPNAME (value_scalar const & y, expression_tree const & x);\
|
||||
ISAACAPI expression_tree OPNAME (value_scalar const & y, for_idx_t const & x);\
|
||||
ISAACAPI expression_tree OPNAME (value_scalar const & y, array_base const & x);\
|
||||
\
|
||||
ISAACAPI math_expression OPNAME (for_idx_t const & y, math_expression const & x);\
|
||||
ISAACAPI math_expression OPNAME (for_idx_t const & y, for_idx_t const & x);\
|
||||
ISAACAPI math_expression OPNAME (for_idx_t const & y, value_scalar const & x);\
|
||||
ISAACAPI math_expression OPNAME (for_idx_t const & y, array_base const & x);
|
||||
ISAACAPI expression_tree OPNAME (for_idx_t const & y, expression_tree const & x);\
|
||||
ISAACAPI expression_tree OPNAME (for_idx_t const & y, for_idx_t const & x);\
|
||||
ISAACAPI expression_tree OPNAME (for_idx_t const & y, value_scalar const & x);\
|
||||
ISAACAPI expression_tree OPNAME (for_idx_t const & y, array_base const & x);
|
||||
|
||||
ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(operator +)
|
||||
ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(operator -)
|
||||
@@ -221,29 +221,29 @@ ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(assign)
|
||||
#undef ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR
|
||||
|
||||
#define ISAAC_DECLARE_ROT(LTYPE, RTYPE, CTYPE, STYPE) \
|
||||
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s);
|
||||
expression_tree rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s);
|
||||
|
||||
ISAAC_DECLARE_ROT(array_base, array_base, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(math_expression, array_base, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(array_base, math_expression, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(math_expression, math_expression, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(expression_tree, array_base, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(array_base, expression_tree, scalar, scalar)
|
||||
ISAAC_DECLARE_ROT(expression_tree, expression_tree, scalar, scalar)
|
||||
|
||||
ISAAC_DECLARE_ROT(array_base, array_base, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(math_expression, array_base, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(array_base, math_expression, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(math_expression, math_expression, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(expression_tree, array_base, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(array_base, expression_tree, value_scalar, value_scalar)
|
||||
ISAAC_DECLARE_ROT(expression_tree, expression_tree, value_scalar, value_scalar)
|
||||
|
||||
ISAAC_DECLARE_ROT(array_base, array_base, math_expression, math_expression)
|
||||
ISAAC_DECLARE_ROT(math_expression, array_base, math_expression, math_expression)
|
||||
ISAAC_DECLARE_ROT(array_base, math_expression, math_expression, math_expression)
|
||||
ISAAC_DECLARE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
ISAAC_DECLARE_ROT(array_base, array_base, expression_tree, expression_tree)
|
||||
ISAAC_DECLARE_ROT(expression_tree, array_base, expression_tree, expression_tree)
|
||||
ISAAC_DECLARE_ROT(array_base, expression_tree, expression_tree, expression_tree)
|
||||
ISAAC_DECLARE_ROT(expression_tree, expression_tree, expression_tree, expression_tree)
|
||||
//--------------------------------
|
||||
|
||||
|
||||
//Unary operators
|
||||
#define ISAAC_DECLARE_UNARY_OPERATOR(OPNAME) \
|
||||
ISAACAPI math_expression OPNAME (array_base const & x);\
|
||||
ISAACAPI math_expression OPNAME (math_expression const & x);
|
||||
ISAACAPI expression_tree OPNAME (array_base const & x);\
|
||||
ISAACAPI expression_tree OPNAME (expression_tree const & x);
|
||||
|
||||
ISAAC_DECLARE_UNARY_OPERATOR(abs)
|
||||
ISAAC_DECLARE_UNARY_OPERATOR(acos)
|
||||
@@ -263,21 +263,21 @@ ISAAC_DECLARE_UNARY_OPERATOR(tan)
|
||||
ISAAC_DECLARE_UNARY_OPERATOR(tanh)
|
||||
ISAAC_DECLARE_UNARY_OPERATOR(trans)
|
||||
|
||||
ISAACAPI math_expression cast(array_base const &, numeric_type dtype);
|
||||
ISAACAPI math_expression cast(math_expression const &, numeric_type dtype);
|
||||
ISAACAPI expression_tree cast(array_base const &, numeric_type dtype);
|
||||
ISAACAPI expression_tree cast(expression_tree const &, numeric_type dtype);
|
||||
|
||||
ISAACAPI math_expression norm(array_base const &, unsigned int order = 2);
|
||||
ISAACAPI math_expression norm(math_expression const &, unsigned int order = 2);
|
||||
ISAACAPI expression_tree norm(array_base const &, unsigned int order = 2);
|
||||
ISAACAPI expression_tree norm(expression_tree const &, unsigned int order = 2);
|
||||
|
||||
#undef ISAAC_DECLARE_UNARY_OPERATOR
|
||||
|
||||
ISAACAPI math_expression repmat(array_base const &, int_t const & rep1, int_t const & rep2);
|
||||
ISAACAPI expression_tree repmat(array_base const &, int_t const & rep1, int_t const & rep2);
|
||||
|
||||
//Matrix reduction
|
||||
|
||||
#define ISAAC_DECLARE_DOT(OPNAME) \
|
||||
ISAACAPI math_expression OPNAME(array_base const & M, int_t axis = -1);\
|
||||
ISAACAPI math_expression OPNAME(math_expression const & M, int_t axis = -1);
|
||||
ISAACAPI expression_tree OPNAME(array_base const & M, int_t axis = -1);\
|
||||
ISAACAPI expression_tree OPNAME(expression_tree const & M, int_t axis = -1);
|
||||
|
||||
ISAAC_DECLARE_DOT(sum)
|
||||
ISAAC_DECLARE_DOT(argmax)
|
||||
@@ -286,10 +286,10 @@ ISAAC_DECLARE_DOT((min))
|
||||
ISAAC_DECLARE_DOT(argmin)
|
||||
|
||||
//Fusion
|
||||
ISAACAPI math_expression fuse(math_expression const & x, math_expression const & y);
|
||||
ISAACAPI expression_tree fuse(expression_tree const & x, expression_tree const & y);
|
||||
|
||||
//For
|
||||
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & expression);
|
||||
ISAACAPI expression_tree sfor(expression_tree const & start, expression_tree const & end, expression_tree const & inc, expression_tree const & expression);
|
||||
static const for_idx_t _i0{0};
|
||||
static const for_idx_t _i1{1};
|
||||
static const for_idx_t _i2{2};
|
||||
@@ -302,41 +302,41 @@ static const for_idx_t _i8{8};
|
||||
static const for_idx_t _i9{9};
|
||||
|
||||
//Initializers
|
||||
ISAACAPI math_expression eye(int_t, int_t, isaac::numeric_type, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
ISAACAPI expression_tree eye(int_t, int_t, isaac::numeric_type, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
ISAACAPI expression_tree zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
|
||||
|
||||
//Swap
|
||||
ISAACAPI void swap(view x, view y);
|
||||
|
||||
//Reshape
|
||||
ISAACAPI math_expression reshape(array_base const &, shape_t const &);
|
||||
ISAACAPI math_expression ravel(array_base const &);
|
||||
ISAACAPI expression_tree reshape(array_base const &, shape_t const &);
|
||||
ISAACAPI expression_tree ravel(array_base const &);
|
||||
|
||||
//diag
|
||||
array diag(array_base & x, int offset = 0);
|
||||
|
||||
//Row
|
||||
ISAACAPI math_expression row(array_base const &, value_scalar const &);
|
||||
ISAACAPI math_expression row(array_base const &, for_idx_t const &);
|
||||
ISAACAPI math_expression row(array_base const &, math_expression const &);
|
||||
ISAACAPI expression_tree row(array_base const &, value_scalar const &);
|
||||
ISAACAPI expression_tree row(array_base const &, for_idx_t const &);
|
||||
ISAACAPI expression_tree row(array_base const &, expression_tree const &);
|
||||
|
||||
ISAACAPI math_expression row(math_expression const &, value_scalar const &);
|
||||
ISAACAPI math_expression row(math_expression const &, for_idx_t const &);
|
||||
ISAACAPI math_expression row(math_expression const &, math_expression const &);
|
||||
ISAACAPI expression_tree row(expression_tree const &, value_scalar const &);
|
||||
ISAACAPI expression_tree row(expression_tree const &, for_idx_t const &);
|
||||
ISAACAPI expression_tree row(expression_tree const &, expression_tree const &);
|
||||
|
||||
//col
|
||||
ISAACAPI math_expression col(array_base const &, value_scalar const &);
|
||||
ISAACAPI math_expression col(array_base const &, for_idx_t const &);
|
||||
ISAACAPI math_expression col(array_base const &, math_expression const &);
|
||||
ISAACAPI expression_tree col(array_base const &, value_scalar const &);
|
||||
ISAACAPI expression_tree col(array_base const &, for_idx_t const &);
|
||||
ISAACAPI expression_tree col(array_base const &, expression_tree const &);
|
||||
|
||||
ISAACAPI math_expression col(math_expression const &, value_scalar const &);
|
||||
ISAACAPI math_expression col(math_expression const &, for_idx_t const &);
|
||||
ISAACAPI math_expression col(math_expression const &, math_expression const &);
|
||||
ISAACAPI expression_tree col(expression_tree const &, value_scalar const &);
|
||||
ISAACAPI expression_tree col(expression_tree const &, for_idx_t const &);
|
||||
ISAACAPI expression_tree col(expression_tree const &, expression_tree const &);
|
||||
|
||||
|
||||
//
|
||||
ISAACAPI std::ostream& operator<<(std::ostream &, array_base const &);
|
||||
ISAACAPI std::ostream& operator<<(std::ostream &, math_expression const &);
|
||||
ISAACAPI std::ostream& operator<<(std::ostream &, expression_tree const &);
|
||||
|
||||
}
|
||||
#endif
|
||||
|
@@ -27,7 +27,7 @@ typedef std::map<mapping_key, std::shared_ptr<mapped_object> > mapping_type;
|
||||
|
||||
/** @brief Mapped Object
|
||||
*
|
||||
* This object populates the symbolic mapping associated with a math_expression. (root_id, LHS|RHS|PARENT) => mapped_object
|
||||
* This object populates the symbolic mapping associated with a expression_tree. (root_id, LHS|RHS|PARENT) => mapped_object
|
||||
* The tree can then be reconstructed in its symbolic form
|
||||
*/
|
||||
class mapped_object
|
||||
@@ -49,9 +49,9 @@ protected:
|
||||
public:
|
||||
struct node_info
|
||||
{
|
||||
node_info(mapping_type const * _mapping, math_expression const * _math_expression, size_t _root_idx);
|
||||
node_info(mapping_type const * _mapping, expression_tree const * _expression_tree, size_t _root_idx);
|
||||
mapping_type const * mapping;
|
||||
isaac::math_expression const * math_expression;
|
||||
isaac::expression_tree const * expression_tree;
|
||||
size_t root_idx;
|
||||
};
|
||||
|
||||
@@ -105,8 +105,8 @@ public:
|
||||
mapped_reduce(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
|
||||
|
||||
size_t root_idx() const;
|
||||
isaac::math_expression const & math_expression() const;
|
||||
math_expression::node root_node() const;
|
||||
isaac::expression_tree const & expression_tree() const;
|
||||
expression_tree::node root_node() const;
|
||||
bool is_index_reduction() const;
|
||||
op_element root_op() const;
|
||||
};
|
||||
@@ -253,7 +253,7 @@ public:
|
||||
mapped_cast(operation_type type, unsigned int id);
|
||||
};
|
||||
|
||||
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);
|
||||
extern mapped_object& get(expression_tree::container_type const &, size_t, mapping_type const &, size_t);
|
||||
|
||||
}
|
||||
#endif
|
||||
|
@@ -13,8 +13,8 @@ namespace detail
|
||||
{
|
||||
|
||||
bool is_node_leaf(op_element const & op);
|
||||
bool is_scalar_reduce_1d(math_expression::node const & node);
|
||||
bool is_vector_reduce_1d(math_expression::node const & node);
|
||||
bool is_scalar_reduce_1d(expression_tree::node const & node);
|
||||
bool is_vector_reduce_1d(expression_tree::node const & node);
|
||||
bool is_assignment(op_element const & op);
|
||||
bool is_elementwise_operator(op_element const & op);
|
||||
bool is_elementwise_function(op_element const & op);
|
||||
@@ -24,58 +24,58 @@ namespace detail
|
||||
|
||||
class scalar;
|
||||
|
||||
/** @brief base functor class for traversing a math_expression */
|
||||
/** @brief base functor class for traversing a expression_tree */
|
||||
class traversal_functor
|
||||
{
|
||||
public:
|
||||
void call_before_expansion(math_expression const &, std::size_t) const { }
|
||||
void call_after_expansion(math_expression const &, std::size_t) const { }
|
||||
void call_before_expansion(expression_tree const &, std::size_t) const { }
|
||||
void call_after_expansion(expression_tree const &, std::size_t) const { }
|
||||
};
|
||||
|
||||
|
||||
/** @brief Recursively execute a functor on a math_expression */
|
||||
/** @brief Recursively execute a functor on a expression_tree */
|
||||
template<class Fun>
|
||||
inline void traverse(isaac::math_expression const & math_expression, std::size_t root_idx, Fun const & fun, bool inspect)
|
||||
inline void traverse(isaac::expression_tree const & expression_tree, std::size_t root_idx, Fun const & fun, bool inspect)
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
bool recurse = detail::is_node_leaf(root_node.op)?inspect:true;
|
||||
bool bypass = detail::bypass(root_node.op);
|
||||
|
||||
if(!bypass)
|
||||
fun.call_before_expansion(math_expression, root_idx);
|
||||
fun.call_before_expansion(expression_tree, root_idx);
|
||||
|
||||
//Lhs:
|
||||
if (recurse)
|
||||
{
|
||||
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.lhs.node_index, fun, inspect);
|
||||
traverse(expression_tree, root_node.lhs.node_index, fun, inspect);
|
||||
if (root_node.lhs.subtype != INVALID_SUBTYPE)
|
||||
fun(math_expression, root_idx, LHS_NODE_TYPE);
|
||||
fun(expression_tree, root_idx, LHS_NODE_TYPE);
|
||||
}
|
||||
|
||||
//Self:
|
||||
if(!bypass)
|
||||
fun(math_expression, root_idx, PARENT_NODE_TYPE);
|
||||
fun(expression_tree, root_idx, PARENT_NODE_TYPE);
|
||||
|
||||
//Rhs:
|
||||
if (recurse && root_node.rhs.subtype!=INVALID_SUBTYPE)
|
||||
{
|
||||
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.rhs.node_index, fun, inspect);
|
||||
traverse(expression_tree, root_node.rhs.node_index, fun, inspect);
|
||||
if (root_node.rhs.subtype != INVALID_SUBTYPE)
|
||||
fun(math_expression, root_idx, RHS_NODE_TYPE);
|
||||
fun(expression_tree, root_idx, RHS_NODE_TYPE);
|
||||
}
|
||||
|
||||
if(!bypass)
|
||||
fun.call_after_expansion(math_expression, root_idx);
|
||||
fun.call_after_expansion(expression_tree, root_idx);
|
||||
}
|
||||
|
||||
class filter_fun : public traversal_functor
|
||||
{
|
||||
public:
|
||||
typedef bool (*pred_t)(math_expression::node const & node);
|
||||
typedef bool (*pred_t)(expression_tree::node const & node);
|
||||
filter_fun(pred_t pred, std::vector<size_t> & out);
|
||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
|
||||
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const;
|
||||
private:
|
||||
pred_t pred_;
|
||||
std::vector<size_t> & out_;
|
||||
@@ -85,22 +85,22 @@ class filter_elements_fun : public traversal_functor
|
||||
{
|
||||
public:
|
||||
filter_elements_fun(node_type subtype, std::vector<tree_node> & out);
|
||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
|
||||
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const;
|
||||
private:
|
||||
node_type subtype_;
|
||||
std::vector<tree_node> & out_;
|
||||
};
|
||||
|
||||
std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node),
|
||||
isaac::math_expression const & math_expression,
|
||||
std::vector<size_t> filter_nodes(bool (*pred)(expression_tree::node const & node),
|
||||
isaac::expression_tree const & expression_tree,
|
||||
size_t root,
|
||||
bool inspect);
|
||||
|
||||
std::vector<tree_node> filter_elements(node_type subtype,
|
||||
isaac::math_expression const & math_expression);
|
||||
isaac::expression_tree const & expression_tree);
|
||||
const char * evaluate(operation_type type);
|
||||
|
||||
/** @brief functor for generating the expression string from a math_expression */
|
||||
/** @brief functor for generating the expression string from a expression_tree */
|
||||
class evaluate_expression_traversal: public traversal_functor
|
||||
{
|
||||
private:
|
||||
@@ -110,24 +110,24 @@ private:
|
||||
|
||||
public:
|
||||
evaluate_expression_traversal(std::map<std::string, std::string> const & accessors, std::string & str, mapping_type const & mapping);
|
||||
void call_before_expansion(isaac::math_expression const & math_expression, std::size_t root_idx) const;
|
||||
void call_after_expansion(math_expression const & /*math_expression*/, std::size_t /*root_idx*/) const;
|
||||
void operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const;
|
||||
void call_before_expansion(isaac::expression_tree const & expression_tree, std::size_t root_idx) const;
|
||||
void call_after_expansion(expression_tree const & /*expression_tree*/, std::size_t /*root_idx*/) const;
|
||||
void operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const;
|
||||
};
|
||||
|
||||
std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
isaac::math_expression const & math_expression, std::size_t root_idx, mapping_type const & mapping);
|
||||
isaac::expression_tree const & expression_tree, std::size_t root_idx, mapping_type const & mapping);
|
||||
|
||||
void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
math_expression const & expressions, mapping_type const & mappings);
|
||||
expression_tree const & expressions, mapping_type const & mappings);
|
||||
|
||||
/** @brief functor for fetching or writing-back the elements in a math_expression */
|
||||
/** @brief functor for fetching or writing-back the elements in a expression_tree */
|
||||
class process_traversal : public traversal_functor
|
||||
{
|
||||
public:
|
||||
process_traversal(std::map<std::string, std::string> const & accessors, kernel_generation_stream & stream,
|
||||
mapping_type const & mapping, std::set<std::string> & already_processed);
|
||||
void operator()(math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const;
|
||||
void operator()(expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const;
|
||||
private:
|
||||
std::map<std::string, std::string> accessors_;
|
||||
kernel_generation_stream & stream_;
|
||||
@@ -136,21 +136,21 @@ private:
|
||||
};
|
||||
|
||||
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
isaac::math_expression const & math_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed);
|
||||
isaac::expression_tree const & expression_tree, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed);
|
||||
|
||||
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
math_expression const & expressions, mapping_type const & mappings);
|
||||
expression_tree const & expressions, mapping_type const & mappings);
|
||||
|
||||
|
||||
class math_expression_representation_functor : public traversal_functor{
|
||||
class expression_tree_representation_functor : public traversal_functor{
|
||||
private:
|
||||
static void append_id(char * & ptr, unsigned int val);
|
||||
void append(driver::Buffer const & h, numeric_type dtype, char prefix) const;
|
||||
void append(tree_node const & lhs_rhs, bool is_assigned) const;
|
||||
public:
|
||||
math_expression_representation_functor(symbolic_binder & binder, char *& ptr);
|
||||
expression_tree_representation_functor(symbolic_binder & binder, char *& ptr);
|
||||
void append(char*& p, const char * str) const;
|
||||
void operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const;
|
||||
void operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf_t) const;
|
||||
private:
|
||||
symbolic_binder & binder_;
|
||||
char *& ptr_;
|
||||
|
@@ -60,20 +60,20 @@ public:
|
||||
unsigned int num_kernels;
|
||||
};
|
||||
protected:
|
||||
static int_t vector_size(math_expression::node const & node);
|
||||
static std::pair<int_t, int_t> matrix_size(math_expression::container_type const & tree, math_expression::node const & node);
|
||||
static bool requires_fallback(math_expression const & expressions);
|
||||
static int_t vector_size(expression_tree::node const & node);
|
||||
static std::pair<int_t, int_t> matrix_size(expression_tree::container_type const & tree, expression_tree::node const & node);
|
||||
static bool requires_fallback(expression_tree const & expressions);
|
||||
private:
|
||||
virtual std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const = 0;
|
||||
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const = 0;
|
||||
public:
|
||||
base(binding_policy_t binding_policy);
|
||||
virtual unsigned int temporary_workspace(math_expression const &) const;
|
||||
virtual unsigned int lmem_usage(math_expression const &) const;
|
||||
virtual unsigned int registers_usage(math_expression const &) const;
|
||||
virtual std::vector<int_t> input_sizes(math_expression const & expressions) const = 0;
|
||||
virtual unsigned int temporary_workspace(expression_tree const &) const;
|
||||
virtual unsigned int lmem_usage(expression_tree const &) const;
|
||||
virtual unsigned int registers_usage(expression_tree const &) const;
|
||||
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
|
||||
virtual ~base();
|
||||
std::string generate(std::string const & suffix, math_expression const & expressions, driver::Device const & device);
|
||||
virtual int is_invalid(math_expression const & expressions, driver::Device const & device) const = 0;
|
||||
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
|
||||
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
|
||||
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & expressions) = 0;
|
||||
virtual std::shared_ptr<base> clone() const = 0;
|
||||
private:
|
||||
@@ -85,7 +85,7 @@ template<class TemplateType, class ParametersType>
|
||||
class base_impl : public base
|
||||
{
|
||||
private:
|
||||
virtual int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
public:
|
||||
typedef ParametersType parameters_type;
|
||||
base_impl(parameters_type const & parameters, binding_policy_t binding_policy);
|
||||
@@ -93,7 +93,7 @@ public:
|
||||
unsigned int local_size_1() const;
|
||||
std::shared_ptr<base> clone() const;
|
||||
/** @brief returns whether or not the profile has undefined behavior on particular device */
|
||||
int is_invalid(math_expression const & expressions, driver::Device const & device) const;
|
||||
int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
|
||||
protected:
|
||||
parameters_type p_;
|
||||
binding_policy_t binding_policy_;
|
||||
|
@@ -19,12 +19,12 @@ public:
|
||||
class elementwise_1d : public base_impl<elementwise_1d, elementwise_1d_parameters>
|
||||
{
|
||||
private:
|
||||
virtual int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const;
|
||||
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const;
|
||||
public:
|
||||
elementwise_1d(elementwise_1d::parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
|
||||
elementwise_1d(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_INDEPENDENT);
|
||||
std::vector<int_t> input_sizes(math_expression const & expressions) const;
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
|
||||
};
|
||||
|
||||
|
@@ -22,12 +22,12 @@ public:
|
||||
class elementwise_2d : public base_impl<elementwise_2d, elementwise_2d_parameters>
|
||||
{
|
||||
private:
|
||||
int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const;
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const;
|
||||
public:
|
||||
elementwise_2d(parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
|
||||
elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_INDEPENDENT);
|
||||
std::vector<int_t> input_sizes(math_expression const & expressions) const;
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
|
||||
};
|
||||
|
||||
|
@@ -41,17 +41,17 @@ struct matrix_product_parameters : public base::parameters_type
|
||||
class matrix_product : public base_impl<matrix_product, matrix_product_parameters>
|
||||
{
|
||||
private:
|
||||
unsigned int temporary_workspace(math_expression const & expressions) const;
|
||||
unsigned int lmem_usage(math_expression const & expressions) const;
|
||||
unsigned int registers_usage(math_expression const & expressions) const;
|
||||
int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const &) const;
|
||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
||||
unsigned int lmem_usage(expression_tree const & expressions) const;
|
||||
unsigned int registers_usage(expression_tree const & expressions) const;
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const &) const;
|
||||
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array_base const & A, array_base const & B, array_base const & C,
|
||||
value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, execution_options_type const & options);
|
||||
std::vector<int_t> infos(math_expression const & expressions, isaac::symbolic::preset::matrix_product::args &arguments) const;
|
||||
std::vector<int_t> infos(expression_tree const & expressions, isaac::symbolic::preset::matrix_product::args &arguments) const;
|
||||
public:
|
||||
matrix_product(matrix_product::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
||||
std::vector<int_t> input_sizes(math_expression const & expressions) const;
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &ctr);
|
||||
private:
|
||||
const char A_trans_;
|
||||
|
@@ -20,17 +20,17 @@ struct reduce_1d_parameters : public base::parameters_type
|
||||
class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
|
||||
{
|
||||
private:
|
||||
unsigned int lmem_usage(math_expression const & expressions) const;
|
||||
int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
unsigned int temporary_workspace(math_expression const & expressions) const;
|
||||
unsigned int lmem_usage(expression_tree const & expressions) const;
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
||||
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs,
|
||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
|
||||
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const;
|
||||
|
||||
public:
|
||||
reduce_1d(reduce_1d::parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
|
||||
reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_INDEPENDENT);
|
||||
std::vector<int_t> input_sizes(math_expression const & expressions) const;
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
|
||||
private:
|
||||
std::vector< driver::Buffer > tmp_;
|
||||
|
@@ -31,12 +31,12 @@ protected:
|
||||
};
|
||||
reduce_2d(reduce_2d::parameters_type const & , reduce_1d_type, binding_policy_t);
|
||||
private:
|
||||
int is_invalid_impl(driver::Device const &, math_expression const &) const;
|
||||
unsigned int lmem_usage(math_expression const &) const;
|
||||
unsigned int temporary_workspace(math_expression const & expressions) const;
|
||||
std::string generate_impl(std::string const & suffix, math_expression const &, driver::Device const & device, mapping_type const &) const;
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
unsigned int lmem_usage(expression_tree const &) const;
|
||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, mapping_type const &) const;
|
||||
public:
|
||||
virtual std::vector<int_t> input_sizes(math_expression const & expressions) const;
|
||||
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
|
||||
private:
|
||||
reduce_1d_type reduce_1d_type_;
|
||||
|
@@ -7,10 +7,10 @@
|
||||
namespace isaac
|
||||
{
|
||||
|
||||
/** @brief Executes a math_expression on the given queue for the given models map*/
|
||||
/** @brief Executes a expression_tree on the given queue for the given models map*/
|
||||
void execute(execution_handler const & , profiles::map_type &);
|
||||
|
||||
/** @brief Executes a math_expression on the default models map*/
|
||||
/** @brief Executes a expression_tree on the default models map*/
|
||||
void execute(execution_handler const &);
|
||||
|
||||
}
|
||||
|
@@ -142,13 +142,13 @@ struct op_element
|
||||
|
||||
struct for_idx_t
|
||||
{
|
||||
math_expression operator=(value_scalar const & ) const;
|
||||
math_expression operator=(math_expression const & ) const;
|
||||
expression_tree operator=(value_scalar const & ) const;
|
||||
expression_tree operator=(expression_tree const & ) const;
|
||||
|
||||
math_expression operator+=(value_scalar const & ) const;
|
||||
math_expression operator-=(value_scalar const & ) const;
|
||||
math_expression operator*=(value_scalar const & ) const;
|
||||
math_expression operator/=(value_scalar const & ) const;
|
||||
expression_tree operator+=(value_scalar const & ) const;
|
||||
expression_tree operator-=(value_scalar const & ) const;
|
||||
expression_tree operator*=(value_scalar const & ) const;
|
||||
expression_tree operator/=(value_scalar const & ) const;
|
||||
|
||||
int level;
|
||||
};
|
||||
@@ -171,7 +171,7 @@ struct tree_node
|
||||
{
|
||||
std::size_t node_index;
|
||||
values_holder vscalar;
|
||||
isaac::array_base* array;
|
||||
array_base* array;
|
||||
for_idx_t for_idx;
|
||||
};
|
||||
};
|
||||
@@ -180,11 +180,11 @@ struct invalid_node{};
|
||||
|
||||
void fill(tree_node &x, for_idx_t index);
|
||||
void fill(tree_node &x, invalid_node);
|
||||
void fill(tree_node & x, std::size_t node_index);
|
||||
void fill(tree_node & x, size_t node_index);
|
||||
void fill(tree_node & x, array_base const & a);
|
||||
void fill(tree_node & x, value_scalar const & v);
|
||||
|
||||
class math_expression
|
||||
class expression_tree
|
||||
{
|
||||
public:
|
||||
struct node
|
||||
@@ -197,20 +197,19 @@ public:
|
||||
typedef std::vector<node> container_type;
|
||||
|
||||
public:
|
||||
math_expression(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
|
||||
math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
expression_tree(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
|
||||
expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
|
||||
template<class LT, class RT>
|
||||
math_expression(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
expression_tree(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
template<class RT>
|
||||
math_expression(math_expression const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
expression_tree(expression_tree const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
template<class LT>
|
||||
math_expression(LT const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
math_expression(math_expression const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
expression_tree(LT const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
expression_tree(expression_tree const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
|
||||
|
||||
shape_t shape() const;
|
||||
math_expression& reshape(int_t size1, int_t size2=1);
|
||||
int_t dim() const;
|
||||
container_type & tree();
|
||||
container_type const & tree() const;
|
||||
@@ -218,8 +217,8 @@ public:
|
||||
driver::Context const & context() const;
|
||||
numeric_type const & dtype() const;
|
||||
|
||||
math_expression operator-();
|
||||
math_expression operator!();
|
||||
expression_tree operator-();
|
||||
expression_tree operator!();
|
||||
private:
|
||||
container_type tree_;
|
||||
std::size_t root_;
|
||||
@@ -279,24 +278,24 @@ struct compilation_options_type
|
||||
class execution_handler
|
||||
{
|
||||
public:
|
||||
execution_handler(math_expression const & x, execution_options_type const& execution_options = execution_options_type(),
|
||||
execution_handler(expression_tree const & x, execution_options_type const& execution_options = execution_options_type(),
|
||||
dispatcher_options_type const & dispatcher_options = dispatcher_options_type(),
|
||||
compilation_options_type const & compilation_options = compilation_options_type())
|
||||
: x_(x), execution_options_(execution_options), dispatcher_options_(dispatcher_options), compilation_options_(compilation_options){}
|
||||
execution_handler(math_expression const & x, execution_handler const & other) : x_(x), execution_options_(other.execution_options_), dispatcher_options_(other.dispatcher_options_), compilation_options_(other.compilation_options_){}
|
||||
math_expression const & x() const { return x_; }
|
||||
execution_handler(expression_tree const & x, execution_handler const & other) : x_(x), execution_options_(other.execution_options_), dispatcher_options_(other.dispatcher_options_), compilation_options_(other.compilation_options_){}
|
||||
expression_tree const & x() const { return x_; }
|
||||
execution_options_type const & execution_options() const { return execution_options_; }
|
||||
dispatcher_options_type const & dispatcher_options() const { return dispatcher_options_; }
|
||||
compilation_options_type const & compilation_options() const { return compilation_options_; }
|
||||
private:
|
||||
math_expression x_;
|
||||
expression_tree x_;
|
||||
execution_options_type execution_options_;
|
||||
dispatcher_options_type dispatcher_options_;
|
||||
compilation_options_type compilation_options_;
|
||||
};
|
||||
|
||||
math_expression::node const & lhs_most(math_expression::container_type const & array_base, math_expression::node const & init);
|
||||
math_expression::node const & lhs_most(math_expression::container_type const & array_base, size_t root);
|
||||
expression_tree::node const & lhs_most(expression_tree::container_type const & array_base, expression_tree::node const & init);
|
||||
expression_tree::node const & lhs_most(expression_tree::container_type const & array_base, size_t root);
|
||||
|
||||
|
||||
}
|
||||
|
@@ -9,8 +9,8 @@ namespace isaac
|
||||
|
||||
std::string to_string(node_type const & f);
|
||||
std::string to_string(tree_node const & e);
|
||||
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
|
||||
std::string to_string(isaac::math_expression const & s);
|
||||
std::ostream & operator<<(std::ostream & os, expression_tree::node const & s_node);
|
||||
std::string to_string(isaac::expression_tree const & s);
|
||||
|
||||
}
|
||||
|
||||
|
@@ -34,10 +34,10 @@ public:
|
||||
};
|
||||
|
||||
private:
|
||||
static void handle_node( math_expression::container_type const &tree, size_t rootidx, args & a);
|
||||
static void handle_node( expression_tree::container_type const &tree, size_t rootidx, args & a);
|
||||
|
||||
public:
|
||||
static args check(math_expression::container_type const &tree, size_t rootidx);
|
||||
static args check(expression_tree::container_type const &tree, size_t rootidx);
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -18,13 +18,13 @@ ISAACAPI typename std::conditional<std::is_arithmetic<T>::value, value_scalar, T
|
||||
{ return wrap_generic(x); }
|
||||
|
||||
template<typename T, typename... Args>
|
||||
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
|
||||
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
|
||||
ISAACAPI expression_tree make_tuple(driver::Context const & context, T const & x, Args... args)
|
||||
{ return expression_tree(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
|
||||
|
||||
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
|
||||
inline value_scalar tuple_get(expression_tree::container_type const & tree, size_t root, size_t idx)
|
||||
{
|
||||
for(unsigned int i = 0 ; i < idx ; ++i){
|
||||
math_expression::node node = tree[root];
|
||||
expression_tree::node node = tree[root];
|
||||
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
root = node.rhs.node_index;
|
||||
else
|
||||
|
@@ -9,7 +9,7 @@ namespace isaac
|
||||
{
|
||||
|
||||
class scalar;
|
||||
class math_expression;
|
||||
class expression_tree;
|
||||
|
||||
union ISAACAPI values_holder
|
||||
{
|
||||
@@ -46,7 +46,7 @@ public:
|
||||
#undef ISAAC_INSTANTIATE
|
||||
value_scalar(values_holder values, numeric_type dtype);
|
||||
explicit value_scalar(scalar const &);
|
||||
explicit value_scalar(math_expression const &);
|
||||
explicit value_scalar(expression_tree const &);
|
||||
explicit value_scalar(numeric_type dtype = INVALID_NUMERIC_TYPE);
|
||||
|
||||
values_holder values() const;
|
||||
|
300
lib/array.cpp
300
lib/array.cpp
@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == rhs.dtype());
|
||||
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression));
|
||||
return *this;
|
||||
}
|
||||
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == rhs.dtype());
|
||||
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression));
|
||||
return *this;
|
||||
}
|
||||
@@ -188,12 +188,12 @@ array_base& array_base::operator=(execution_handler const & c)
|
||||
{
|
||||
if(shape_.min()==0) return *this;
|
||||
assert(dtype_ == c.x().dtype());
|
||||
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
expression_tree expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
|
||||
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
array_base & array_base::operator=(math_expression const & rhs)
|
||||
array_base & array_base::operator=(expression_tree const & rhs)
|
||||
{
|
||||
return *this = execution_handler(rhs);
|
||||
}
|
||||
@@ -227,54 +227,54 @@ INSTANTIATE(double);
|
||||
|
||||
|
||||
|
||||
math_expression array_base::operator-()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
expression_tree array_base::operator-()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
math_expression array_base::operator!()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
expression_tree array_base::operator!()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
|
||||
//
|
||||
array_base & array_base::operator+=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator+=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator+=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator+=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator-=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator-=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator-=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator-=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator*=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator*=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator*=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator*=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
|
||||
//----
|
||||
array_base & array_base::operator/=(value_scalar const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator/=(array_base const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
|
||||
|
||||
array_base & array_base::operator/=(math_expression const & rhs)
|
||||
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
array_base & array_base::operator/=(expression_tree const & rhs)
|
||||
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
|
||||
|
||||
/*--- Indexing operators -----*/
|
||||
//---------------------------------------
|
||||
math_expression array_base::operator[](for_idx_t idx) const
|
||||
expression_tree array_base::operator[](for_idx_t idx) const
|
||||
{
|
||||
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
||||
return expression_tree(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
|
||||
}
|
||||
|
||||
scalar array_base::operator [](int_t idx)
|
||||
@@ -324,7 +324,7 @@ view array_base::operator()(slice const & si, slice const & sj)
|
||||
//---------------------------------------
|
||||
/*--- array ---*/
|
||||
|
||||
array::array(math_expression const & proxy) : array_base(execution_handler(proxy)) {}
|
||||
array::array(expression_tree const & proxy) : array_base(execution_handler(proxy)) {}
|
||||
|
||||
array::array(array_base const & other): array_base(other.dtype(), other.shape(), other.context())
|
||||
{ *this = other; }
|
||||
@@ -379,7 +379,7 @@ scalar::scalar(value_scalar value, driver::Context const & context) : array_base
|
||||
scalar::scalar(numeric_type dtype, driver::Context const & context) : array_base(1, dtype, context)
|
||||
{ }
|
||||
|
||||
scalar::scalar(math_expression const & proxy) : array_base(proxy){ }
|
||||
scalar::scalar(expression_tree const & proxy) : array_base(proxy){ }
|
||||
|
||||
void scalar::inject(values_holder & v) const
|
||||
{
|
||||
@@ -511,53 +511,53 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
|
||||
}
|
||||
|
||||
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
|
||||
math_expression OPNAME (array_base const & x, math_expression const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (array_base const & x, expression_tree const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, array_base const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
||||
expression_tree OPNAME (array_base const & x, array_base const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x, value_scalar const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, math_expression const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (expression_tree const & x, expression_tree const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, array_base const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
expression_tree OPNAME (expression_tree const & x, array_base const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (expression_tree const & x, value_scalar const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (expression_tree const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (value_scalar const & y, expression_tree const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & y, array_base const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (value_scalar const & y, array_base const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||
expression_tree OPNAME (value_scalar const & x, for_idx_t const & y) \
|
||||
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
|
||||
\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
expression_tree OPNAME (for_idx_t const & y, expression_tree const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||
expression_tree OPNAME (for_idx_t const & y, value_scalar const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
expression_tree OPNAME (for_idx_t const & y, array_base const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
|
||||
expression_tree OPNAME (for_idx_t const & y, for_idx_t const & x) \
|
||||
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
|
||||
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
|
||||
@@ -580,39 +580,39 @@ DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
|
||||
|
||||
#define DEFINE_OUTER(LTYPE, RTYPE) \
|
||||
math_expression outer(LTYPE const & x, RTYPE const & y)\
|
||||
expression_tree outer(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
assert(x.dim()<=1 && y.dim()<=1);\
|
||||
if(x.dim()<1 || y.dim()<1)\
|
||||
return x*y;\
|
||||
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
||||
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
|
||||
}\
|
||||
|
||||
DEFINE_OUTER(array_base, array_base)
|
||||
DEFINE_OUTER(math_expression, array_base)
|
||||
DEFINE_OUTER(array_base, math_expression)
|
||||
DEFINE_OUTER(math_expression, math_expression)
|
||||
DEFINE_OUTER(expression_tree, array_base)
|
||||
DEFINE_OUTER(array_base, expression_tree)
|
||||
DEFINE_OUTER(expression_tree, expression_tree)
|
||||
|
||||
#undef DEFINE_ELEMENT_BINARY_OPERATOR
|
||||
|
||||
#define DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE)\
|
||||
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
|
||||
expression_tree rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
|
||||
{ return fuse(assign(x, c*x + s*y), assign(y, c*y - s*x)); }
|
||||
|
||||
DEFINE_ROT(array_base, array_base, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, array_base, scalar, scalar)
|
||||
DEFINE_ROT(array_base, math_expression, scalar, scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, scalar, scalar)
|
||||
DEFINE_ROT(expression_tree, array_base, scalar, scalar)
|
||||
DEFINE_ROT(array_base, expression_tree, scalar, scalar)
|
||||
DEFINE_ROT(expression_tree, expression_tree, scalar, scalar)
|
||||
|
||||
DEFINE_ROT(array_base, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(array_base, math_expression, value_scalar, value_scalar)
|
||||
DEFINE_ROT(math_expression, math_expression, value_scalar, value_scalar)
|
||||
DEFINE_ROT(expression_tree, array_base, value_scalar, value_scalar)
|
||||
DEFINE_ROT(array_base, expression_tree, value_scalar, value_scalar)
|
||||
DEFINE_ROT(expression_tree, expression_tree, value_scalar, value_scalar)
|
||||
|
||||
DEFINE_ROT(array_base, array_base, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, array_base, math_expression, math_expression)
|
||||
DEFINE_ROT(array_base, math_expression, math_expression, math_expression)
|
||||
DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
DEFINE_ROT(array_base, array_base, expression_tree, expression_tree)
|
||||
DEFINE_ROT(expression_tree, array_base, expression_tree, expression_tree)
|
||||
DEFINE_ROT(array_base, expression_tree, expression_tree, expression_tree)
|
||||
DEFINE_ROT(expression_tree, expression_tree, expression_tree, expression_tree)
|
||||
|
||||
|
||||
|
||||
@@ -621,11 +621,11 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
|
||||
/*--- Math Operators----*/
|
||||
//---------------------------------------
|
||||
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
|
||||
math_expression OPNAME (array_base const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
expression_tree OPNAME (array_base const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
|
||||
\
|
||||
math_expression OPNAME (math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
expression_tree OPNAME (expression_tree const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
|
||||
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
|
||||
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
|
||||
@@ -669,14 +669,14 @@ inline operation_type casted(numeric_type dtype)
|
||||
}
|
||||
}
|
||||
|
||||
math_expression cast(array_base const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
expression_tree cast(array_base const & x, numeric_type dtype)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
math_expression cast(math_expression const & x, numeric_type dtype)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
expression_tree cast(expression_tree const & x, numeric_type dtype)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
|
||||
|
||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
|
||||
isaac::expression_tree eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return expression_tree(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
|
||||
|
||||
array diag(array_base & x, int offset)
|
||||
{
|
||||
@@ -688,8 +688,8 @@ array diag(array_base & x, int offset)
|
||||
}
|
||||
|
||||
|
||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
|
||||
isaac::expression_tree zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return expression_tree(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
|
||||
|
||||
inline shape_t flip(shape_t const & shape)
|
||||
{
|
||||
@@ -702,77 +702,77 @@ inline shape_t flip(shape_t const & shape)
|
||||
//inline size4 prod(size4 const & shape1, size4 const & shape2)
|
||||
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
|
||||
|
||||
math_expression trans(array_base const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
expression_tree trans(array_base const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
|
||||
\
|
||||
math_expression trans(math_expression const & x) \
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
expression_tree trans(expression_tree const & x) \
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
|
||||
|
||||
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
||||
expression_tree repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
}
|
||||
|
||||
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
|
||||
expression_tree repmat(expression_tree const & A, int_t const & rep1, int_t const & rep2)
|
||||
{
|
||||
int_t sub1 = A.shape()[0];
|
||||
int_t sub2 = A.dim()==2?A.shape()[1]:1;
|
||||
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
|
||||
}
|
||||
|
||||
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
|
||||
math_expression row(TYPEA const & x, TYPEB const & i)\
|
||||
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
||||
expression_tree row(TYPEA const & x, TYPEB const & i)\
|
||||
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
|
||||
|
||||
DEFINE_ACCESS_ROW(array_base, value_scalar)
|
||||
DEFINE_ACCESS_ROW(array_base, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(array_base, math_expression)
|
||||
DEFINE_ACCESS_ROW(array_base, expression_tree)
|
||||
|
||||
DEFINE_ACCESS_ROW(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_ROW(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(math_expression, math_expression)
|
||||
DEFINE_ACCESS_ROW(expression_tree, value_scalar)
|
||||
DEFINE_ACCESS_ROW(expression_tree, for_idx_t)
|
||||
DEFINE_ACCESS_ROW(expression_tree, expression_tree)
|
||||
|
||||
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
|
||||
math_expression col(TYPEA const & x, TYPEB const & i)\
|
||||
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
||||
expression_tree col(TYPEA const & x, TYPEB const & i)\
|
||||
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
|
||||
|
||||
DEFINE_ACCESS_COL(array_base, value_scalar)
|
||||
DEFINE_ACCESS_COL(array_base, for_idx_t)
|
||||
DEFINE_ACCESS_COL(array_base, math_expression)
|
||||
DEFINE_ACCESS_COL(array_base, expression_tree)
|
||||
|
||||
DEFINE_ACCESS_COL(math_expression, value_scalar)
|
||||
DEFINE_ACCESS_COL(math_expression, for_idx_t)
|
||||
DEFINE_ACCESS_COL(math_expression, math_expression)
|
||||
DEFINE_ACCESS_COL(expression_tree, value_scalar)
|
||||
DEFINE_ACCESS_COL(expression_tree, for_idx_t)
|
||||
DEFINE_ACCESS_COL(expression_tree, expression_tree)
|
||||
|
||||
////---------------------------------------
|
||||
|
||||
///*--- Reductions ---*/
|
||||
////---------------------------------------
|
||||
#define DEFINE_REDUCTION(OP, OPNAME)\
|
||||
math_expression OPNAME(array_base const & x, int_t axis)\
|
||||
expression_tree OPNAME(array_base const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.dim())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
else if(axis==-1)\
|
||||
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
else if(axis==0)\
|
||||
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
else\
|
||||
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
}\
|
||||
\
|
||||
math_expression OPNAME(math_expression const & x, int_t axis)\
|
||||
expression_tree OPNAME(expression_tree const & x, int_t axis)\
|
||||
{\
|
||||
if(axis < -1 || axis > x.dim())\
|
||||
throw std::out_of_range("The axis entry is out of bounds");\
|
||||
if(axis==-1)\
|
||||
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
|
||||
else if(axis==0)\
|
||||
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
|
||||
else\
|
||||
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
|
||||
}
|
||||
|
||||
DEFINE_REDUCTION(ADD_TYPE, sum)
|
||||
@@ -786,51 +786,51 @@ DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
|
||||
namespace detail
|
||||
{
|
||||
|
||||
math_expression matmatprod(array_base const & A, array_base const & B)
|
||||
expression_tree matmatprod(array_base const & A, array_base const & B)
|
||||
{
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
return expression_tree(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
||||
}
|
||||
|
||||
math_expression matmatprod(math_expression const & A, array_base const & B)
|
||||
expression_tree matmatprod(expression_tree const & A, array_base const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
if(A_trans){
|
||||
type = MATRIX_PRODUCT_TN_TYPE;
|
||||
}
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
math_expression matmatprod(array_base const & A, math_expression const & B)
|
||||
expression_tree matmatprod(array_base const & A, expression_tree const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
|
||||
bool B_trans = B_root.op.type==TRANS_TYPE;
|
||||
if(B_trans){
|
||||
type = MATRIX_PRODUCT_NT_TYPE;
|
||||
}
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
math_expression matmatprod(math_expression const & A, math_expression const & B)
|
||||
expression_tree matmatprod(expression_tree const & A, expression_tree const & B)
|
||||
{
|
||||
operation_type type = MATRIX_PRODUCT_NN_TYPE;
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
|
||||
shape_t shape{A.shape()[0], B.shape()[1]};
|
||||
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
@@ -841,15 +841,15 @@ namespace detail
|
||||
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
|
||||
else type = MATRIX_PRODUCT_NN_TYPE;
|
||||
|
||||
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
|
||||
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
|
||||
if(A_trans) res_root.lhs = A_root.lhs;
|
||||
if(B_trans) res_root.rhs = B_root.lhs;
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
math_expression matvecprod(array_base const & A, T const & x)
|
||||
expression_tree matvecprod(array_base const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
@@ -857,11 +857,11 @@ namespace detail
|
||||
}
|
||||
|
||||
template<class T>
|
||||
math_expression matvecprod(math_expression const & A, T const & x)
|
||||
expression_tree matvecprod(expression_tree const & A, T const & x)
|
||||
{
|
||||
int_t M = A.shape()[0];
|
||||
int_t N = A.shape()[1];
|
||||
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
|
||||
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
|
||||
bool A_trans = A_root.op.type==TRANS_TYPE;
|
||||
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
|
||||
A_root = A.tree()[A_root.lhs.node_index];
|
||||
@@ -869,7 +869,7 @@ namespace detail
|
||||
}
|
||||
if(A_trans)
|
||||
{
|
||||
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
||||
expression_tree tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
|
||||
//Remove trans
|
||||
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
|
||||
return sum(tmp, 0);
|
||||
@@ -889,17 +889,17 @@ ISAACAPI void swap(view x, view y)
|
||||
}
|
||||
|
||||
//Reshape
|
||||
math_expression reshape(array_base const & x, shape_t const & shape)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
expression_tree reshape(array_base const & x, shape_t const & shape)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
|
||||
math_expression reshape(math_expression const & x, shape_t const & shape)
|
||||
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
expression_tree reshape(expression_tree const & x, shape_t const & shape)
|
||||
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
|
||||
|
||||
math_expression ravel(array_base const & x)
|
||||
expression_tree ravel(array_base const & x)
|
||||
{ return reshape(x, {x.shape().prod()}); }
|
||||
|
||||
#define DEFINE_DOT(LTYPE, RTYPE) \
|
||||
math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
expression_tree dot(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
numeric_type dtype = x.dtype();\
|
||||
driver::Context const & context = x.context();\
|
||||
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
if(x.dim()==2 && x.shape()[1]==0)\
|
||||
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
|
||||
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
|
||||
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
|
||||
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
|
||||
if(x.dim()==1 && y.dim()==1)\
|
||||
return sum(x*y);\
|
||||
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
|
||||
@@ -940,15 +940,15 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
}
|
||||
|
||||
DEFINE_DOT(array_base, array_base)
|
||||
DEFINE_DOT(math_expression, array_base)
|
||||
DEFINE_DOT(array_base, math_expression)
|
||||
DEFINE_DOT(math_expression, math_expression)
|
||||
DEFINE_DOT(expression_tree, array_base)
|
||||
DEFINE_DOT(array_base, expression_tree)
|
||||
DEFINE_DOT(expression_tree, expression_tree)
|
||||
|
||||
#undef DEFINE_DOT
|
||||
|
||||
|
||||
#define DEFINE_NORM(TYPE)\
|
||||
math_expression norm(TYPE const & x, unsigned int order)\
|
||||
expression_tree norm(TYPE const & x, unsigned int order)\
|
||||
{\
|
||||
assert(order > 0 && order < 3);\
|
||||
switch(order)\
|
||||
@@ -959,21 +959,21 @@ math_expression norm(TYPE const & x, unsigned int order)\
|
||||
}
|
||||
|
||||
DEFINE_NORM(array_base)
|
||||
DEFINE_NORM(math_expression)
|
||||
DEFINE_NORM(expression_tree)
|
||||
|
||||
#undef DEFINE_NORM
|
||||
|
||||
/*--- Fusion ----*/
|
||||
math_expression fuse(math_expression const & x, math_expression const & y)
|
||||
expression_tree fuse(expression_tree const & x, expression_tree const & y)
|
||||
{
|
||||
assert(x.context()==y.context());
|
||||
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
||||
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
|
||||
}
|
||||
|
||||
/*--- For loops ---*/
|
||||
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
|
||||
ISAACAPI expression_tree sfor(expression_tree const & start, expression_tree const & end, expression_tree const & inc, expression_tree const & x)
|
||||
{
|
||||
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
||||
return expression_tree(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
|
||||
}
|
||||
|
||||
|
||||
@@ -1160,7 +1160,7 @@ std::ostream& operator<<(std::ostream & os, array_base const & a)
|
||||
return os;
|
||||
}
|
||||
|
||||
ISAACAPI std::ostream& operator<<(std::ostream & oss, math_expression const & expression)
|
||||
ISAACAPI std::ostream& operator<<(std::ostream & oss, expression_tree const & expression)
|
||||
{
|
||||
return oss << array(expression);
|
||||
}
|
||||
|
@@ -51,8 +51,8 @@ void mapped_object::register_attribute(std::string & attribute, std::string cons
|
||||
keywords_[key] = attribute;
|
||||
}
|
||||
|
||||
mapped_object::node_info::node_info(mapping_type const * _mapping, isaac::math_expression const * _math_expression, size_t _root_idx) :
|
||||
mapping(_mapping), math_expression(_math_expression), root_idx(_root_idx) { }
|
||||
mapped_object::node_info::node_info(mapping_type const * _mapping, isaac::expression_tree const * _expression_tree, size_t _root_idx) :
|
||||
mapping(_mapping), expression_tree(_expression_tree), root_idx(_root_idx) { }
|
||||
|
||||
mapped_object::mapped_object(std::string const & scalartype, std::string const & name, std::string const & type_key) : type_key_(type_key), name_(name)
|
||||
{
|
||||
@@ -92,10 +92,10 @@ std::string mapped_object::evaluate(std::map<std::string, std::string> const & a
|
||||
return process(accessors.at(type_key_));
|
||||
}
|
||||
|
||||
mapped_object& get(math_expression::container_type const & tree, size_t root, mapping_type const & mapping, size_t idx)
|
||||
mapped_object& get(expression_tree::container_type const & tree, size_t root, mapping_type const & mapping, size_t idx)
|
||||
{
|
||||
for(unsigned int i = 0 ; i < idx ; ++i){
|
||||
math_expression::node node = tree[root];
|
||||
expression_tree::node node = tree[root];
|
||||
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
root = node.rhs.node_index;
|
||||
else
|
||||
@@ -108,12 +108,12 @@ binary_leaf::binary_leaf(mapped_object::node_info info) : info_(info){ }
|
||||
|
||||
void binary_leaf::process_recursive(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors, std::set<std::string> & already_fetched)
|
||||
{
|
||||
process(stream, leaf, accessors, *info_.math_expression, info_.root_idx, *info_.mapping, already_fetched);
|
||||
process(stream, leaf, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping, already_fetched);
|
||||
}
|
||||
|
||||
std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, std::string> const & accessors)
|
||||
{
|
||||
return evaluate(leaf, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
return evaluate(leaf, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
|
||||
@@ -127,11 +127,11 @@ mapped_reduce::mapped_reduce(std::string const & scalartype, unsigned int id, no
|
||||
size_t mapped_reduce::root_idx() const
|
||||
{ return info_.root_idx; }
|
||||
|
||||
isaac::math_expression const & mapped_reduce::math_expression() const
|
||||
{ return *info_.math_expression; }
|
||||
isaac::expression_tree const & mapped_reduce::expression_tree() const
|
||||
{ return *info_.expression_tree; }
|
||||
|
||||
math_expression::node mapped_reduce::root_node() const
|
||||
{ return math_expression().tree()[root_idx()]; }
|
||||
expression_tree::node mapped_reduce::root_node() const
|
||||
{ return expression_tree().tree()[root_idx()]; }
|
||||
|
||||
bool mapped_reduce::is_index_reduction() const
|
||||
{
|
||||
@@ -144,7 +144,7 @@ bool mapped_reduce::is_index_reduction() const
|
||||
|
||||
op_element mapped_reduce::root_op() const
|
||||
{
|
||||
return info_.math_expression->tree()[info_.root_idx].op;
|
||||
return info_.expression_tree->tree()[info_.root_idx].op;
|
||||
}
|
||||
|
||||
|
||||
@@ -214,10 +214,10 @@ mapped_array::mapped_array(std::string const & scalartype, unsigned int id, std:
|
||||
void mapped_vdiag::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
|
||||
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
|
||||
accessors["arrayn"] = res;
|
||||
accessors["host_scalar"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_vdiag::mapped_vdiag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "vdiag"), binary_leaf(info){}
|
||||
@@ -226,10 +226,10 @@ mapped_vdiag::mapped_vdiag(std::string const & scalartype, unsigned int id, node
|
||||
void mapped_array_access::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
tools::find_and_replace(res, "#index", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
|
||||
tools::find_and_replace(res, "#index", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
|
||||
accessors["arrayn"] = res;
|
||||
accessors["arraynn"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_array_access::mapped_array_access(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "array_access"), binary_leaf(info)
|
||||
@@ -239,9 +239,9 @@ mapped_array_access::mapped_array_access(std::string const & scalartype, unsigne
|
||||
void mapped_matrix_row::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
tools::find_and_replace(res, "#row", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
|
||||
tools::find_and_replace(res, "#row", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
|
||||
accessors["arraynn"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_matrix_row::mapped_matrix_row(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_row"), binary_leaf(info)
|
||||
@@ -251,9 +251,9 @@ mapped_matrix_row::mapped_matrix_row(std::string const & scalartype, unsigned in
|
||||
void mapped_matrix_column::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
tools::find_and_replace(res, "#column", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
|
||||
tools::find_and_replace(res, "#column", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
|
||||
accessors["arraynn"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_matrix_column::mapped_matrix_column(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_column"), binary_leaf(info)
|
||||
@@ -265,7 +265,7 @@ mapped_matrix_column::mapped_matrix_column(std::string const & scalartype, unsig
|
||||
//
|
||||
char mapped_repeat::get_type(node_info const & info)
|
||||
{
|
||||
math_expression::container_type const & tree = info.math_expression->tree();
|
||||
expression_tree::container_type const & tree = info.expression_tree->tree();
|
||||
size_t tuple_root = tree[info.root_idx].rhs.node_index;
|
||||
|
||||
int sub0 = tuple_get(tree, tuple_root, 2);
|
||||
@@ -283,7 +283,7 @@ char mapped_repeat::get_type(node_info const & info)
|
||||
void mapped_repeat::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
math_expression::container_type const & tree = info_.math_expression->tree();
|
||||
expression_tree::container_type const & tree = info_.expression_tree->tree();
|
||||
size_t tuple_root = tree[info_.root_idx].rhs.node_index;
|
||||
|
||||
tools::find_and_replace(res, "#rep0", get(tree, tuple_root, *info_.mapping, 0).process("#name"));
|
||||
@@ -312,7 +312,7 @@ void mapped_repeat::postprocess(std::string &res) const
|
||||
accessors["array1n"] = res;
|
||||
accessors["arrayn1"] = res;
|
||||
accessors["arraynn"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_repeat::mapped_repeat(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "repeat"), binary_leaf(info), type_(get_type(info))
|
||||
@@ -324,9 +324,9 @@ mapped_repeat::mapped_repeat(std::string const & scalartype, unsigned int id, no
|
||||
void mapped_matrix_diag::postprocess(std::string &res) const
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
|
||||
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
|
||||
accessors["arraynn"] = res;
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
|
||||
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
|
||||
}
|
||||
|
||||
mapped_matrix_diag::mapped_matrix_diag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_diag"), binary_leaf(info)
|
||||
@@ -343,7 +343,7 @@ void mapped_outer::postprocess(std::string &res) const
|
||||
std::map<std::string, std::string> accessors;
|
||||
accessors["arrayn"] = "$VALUE{"+i+"}";
|
||||
accessors["array1"] = "#namereg";
|
||||
return isaac::evaluate(leaf_, accessors, *i_.math_expression, i_.root_idx, *i_.mapping);
|
||||
return isaac::evaluate(leaf_, accessors, *i_.expression_tree, i_.root_idx, *i_.mapping);
|
||||
}
|
||||
std::string operator()(std::string const &, std::string const &) const{return "";}
|
||||
private:
|
||||
|
@@ -14,12 +14,12 @@ namespace detail
|
||||
|
||||
|
||||
|
||||
bool is_scalar_reduce_1d(math_expression::node const & node)
|
||||
bool is_scalar_reduce_1d(expression_tree::node const & node)
|
||||
{
|
||||
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY;
|
||||
}
|
||||
|
||||
bool is_vector_reduce_1d(math_expression::node const & node)
|
||||
bool is_vector_reduce_1d(expression_tree::node const & node)
|
||||
{
|
||||
return node.op.type_family==ROWS_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY;
|
||||
@@ -123,18 +123,18 @@ namespace detail
|
||||
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
|
||||
{ }
|
||||
|
||||
void filter_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf) const
|
||||
void filter_fun::operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf) const
|
||||
{
|
||||
math_expression::node const * root_node = &math_expression.tree()[root_idx];
|
||||
expression_tree::node const * root_node = &expression_tree.tree()[root_idx];
|
||||
if (leaf==PARENT_NODE_TYPE && pred_(*root_node))
|
||||
out_.push_back(root_idx);
|
||||
}
|
||||
|
||||
//
|
||||
std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node), isaac::math_expression const & math_expression, size_t root, bool inspect)
|
||||
std::vector<size_t> filter_nodes(bool (*pred)(expression_tree::node const & node), isaac::expression_tree const & expression_tree, size_t root, bool inspect)
|
||||
{
|
||||
std::vector<size_t> res;
|
||||
traverse(math_expression, root, filter_fun(pred, res), inspect);
|
||||
traverse(expression_tree, root, filter_fun(pred, res), inspect);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -143,9 +143,9 @@ filter_elements_fun::filter_elements_fun(node_type subtype, std::vector<tree_nod
|
||||
subtype_(subtype), out_(out)
|
||||
{ }
|
||||
|
||||
void filter_elements_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const
|
||||
void filter_elements_fun::operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const
|
||||
{
|
||||
math_expression::node const * root_node = &math_expression.tree()[root_idx];
|
||||
expression_tree::node const * root_node = &expression_tree.tree()[root_idx];
|
||||
if (root_node->lhs.subtype==subtype_)
|
||||
out_.push_back(root_node->lhs);
|
||||
if (root_node->rhs.subtype==subtype_)
|
||||
@@ -153,10 +153,10 @@ void filter_elements_fun::operator()(isaac::math_expression const & math_express
|
||||
}
|
||||
|
||||
|
||||
std::vector<tree_node> filter_elements(node_type subtype, isaac::math_expression const & math_expression)
|
||||
std::vector<tree_node> filter_elements(node_type subtype, isaac::expression_tree const & expression_tree)
|
||||
{
|
||||
std::vector<tree_node> res;
|
||||
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
|
||||
traverse(expression_tree, expression_tree.root(), filter_elements_fun(subtype, res), true);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -240,9 +240,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
|
||||
accessors_(accessors), str_(str), mapping_(mapping)
|
||||
{ }
|
||||
|
||||
void evaluate_expression_traversal::call_before_expansion(isaac::math_expression const & math_expression, std::size_t root_idx) const
|
||||
void evaluate_expression_traversal::call_before_expansion(isaac::expression_tree const & expression_tree, std::size_t root_idx) const
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
if(detail::is_cast(root_node.op))
|
||||
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
|
||||
else if (( (root_node.op.type_family==UNARY_TYPE_FAMILY&&root_node.op.type!=ADD_TYPE) || detail::is_elementwise_function(root_node.op))
|
||||
@@ -253,16 +253,16 @@ void evaluate_expression_traversal::call_before_expansion(isaac::math_expression
|
||||
|
||||
}
|
||||
|
||||
void evaluate_expression_traversal::call_after_expansion(math_expression const & math_expression, std::size_t root_idx) const
|
||||
void evaluate_expression_traversal::call_after_expansion(expression_tree const & expression_tree, std::size_t root_idx) const
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
if(root_node.op.type!=OPERATOR_FUSE)
|
||||
str_+=")";
|
||||
}
|
||||
|
||||
void evaluate_expression_traversal::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const
|
||||
void evaluate_expression_traversal::operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
mapping_type::key_type key = std::make_pair(root_idx, leaf);
|
||||
if (leaf==PARENT_NODE_TYPE)
|
||||
{
|
||||
@@ -304,34 +304,34 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
|
||||
|
||||
|
||||
std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
isaac::math_expression const & math_expression, std::size_t root_idx, mapping_type const & mapping)
|
||||
isaac::expression_tree const & expression_tree, std::size_t root_idx, mapping_type const & mapping)
|
||||
{
|
||||
std::string res;
|
||||
evaluate_expression_traversal traversal_functor(accessors, res, mapping);
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
|
||||
if (leaf==RHS_NODE_TYPE)
|
||||
{
|
||||
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
|
||||
traverse(expression_tree, root_node.rhs.node_index, traversal_functor, false);
|
||||
else
|
||||
traversal_functor(math_expression, root_idx, leaf);
|
||||
traversal_functor(expression_tree, root_idx, leaf);
|
||||
}
|
||||
else if (leaf==LHS_NODE_TYPE)
|
||||
{
|
||||
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
|
||||
traverse(expression_tree, root_node.lhs.node_index, traversal_functor, false);
|
||||
else
|
||||
traversal_functor(math_expression, root_idx, leaf);
|
||||
traversal_functor(expression_tree, root_idx, leaf);
|
||||
}
|
||||
else
|
||||
traverse(math_expression, root_idx, traversal_functor, false);
|
||||
traverse(expression_tree, root_idx, traversal_functor, false);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
math_expression const & x, mapping_type const & mapping)
|
||||
expression_tree const & x, mapping_type const & mapping)
|
||||
{
|
||||
stream << evaluate(leaf, accessors, x, x.root(), mapping) << std::endl;
|
||||
}
|
||||
@@ -341,7 +341,7 @@ process_traversal::process_traversal(std::map<std::string, std::string> const &
|
||||
accessors_(accessors), stream_(stream), mapping_(mapping), already_processed_(already_processed)
|
||||
{ }
|
||||
|
||||
void process_traversal::operator()(math_expression const & /*math_expression*/, std::size_t root_idx, leaf_t leaf) const
|
||||
void process_traversal::operator()(expression_tree const & /*expression_tree*/, std::size_t root_idx, leaf_t leaf) const
|
||||
{
|
||||
mapping_type::const_iterator it = mapping_.find(std::make_pair(root_idx, leaf));
|
||||
if (it!=mapping_.end())
|
||||
@@ -362,41 +362,41 @@ void process_traversal::operator()(math_expression const & /*math_expression*/,
|
||||
|
||||
|
||||
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
isaac::math_expression const & math_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
|
||||
isaac::expression_tree const & expression_tree, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
|
||||
{
|
||||
process_traversal traversal_functor(accessors, stream, mapping, already_processed);
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
|
||||
if (leaf==RHS_NODE_TYPE)
|
||||
{
|
||||
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
|
||||
traverse(expression_tree, root_node.rhs.node_index, traversal_functor, true);
|
||||
else
|
||||
traversal_functor(math_expression, root_idx, leaf);
|
||||
traversal_functor(expression_tree, root_idx, leaf);
|
||||
}
|
||||
else if (leaf==LHS_NODE_TYPE)
|
||||
{
|
||||
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
|
||||
traverse(expression_tree, root_node.lhs.node_index, traversal_functor, true);
|
||||
else
|
||||
traversal_functor(math_expression, root_idx, leaf);
|
||||
traversal_functor(expression_tree, root_idx, leaf);
|
||||
}
|
||||
else
|
||||
{
|
||||
traverse(math_expression, root_idx, traversal_functor, true);
|
||||
traverse(expression_tree, root_idx, traversal_functor, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
|
||||
math_expression const & x, mapping_type const & mapping)
|
||||
expression_tree const & x, mapping_type const & mapping)
|
||||
{
|
||||
std::set<std::string> processed;
|
||||
process(stream, leaf, accessors, x, x.root(), mapping, processed);
|
||||
}
|
||||
|
||||
|
||||
void math_expression_representation_functor::append_id(char * & ptr, unsigned int val)
|
||||
void expression_tree_representation_functor::append_id(char * & ptr, unsigned int val)
|
||||
{
|
||||
if (val==0)
|
||||
*ptr++='0';
|
||||
@@ -408,14 +408,14 @@ void math_expression_representation_functor::append_id(char * & ptr, unsigned in
|
||||
}
|
||||
}
|
||||
|
||||
//void math_expression_representation_functor::append(driver::Buffer const & h, numeric_type dtype, char prefix, bool is_assigned) const
|
||||
//void expression_tree_representation_functor::append(driver::Buffer const & h, numeric_type dtype, char prefix, bool is_assigned) const
|
||||
//{
|
||||
// *ptr_++=prefix;
|
||||
// *ptr_++=(char)dtype;
|
||||
// append_id(ptr_, binder_.get(h, is_assigned));
|
||||
//}
|
||||
|
||||
void math_expression_representation_functor::append(tree_node const & lhs_rhs, bool is_assigned) const
|
||||
void expression_tree_representation_functor::append(tree_node const & lhs_rhs, bool is_assigned) const
|
||||
{
|
||||
if(lhs_rhs.subtype==DENSE_ARRAY_TYPE)
|
||||
{
|
||||
@@ -428,18 +428,18 @@ void math_expression_representation_functor::append(tree_node const & lhs_rhs, b
|
||||
}
|
||||
}
|
||||
|
||||
math_expression_representation_functor::math_expression_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }
|
||||
expression_tree_representation_functor::expression_tree_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }
|
||||
|
||||
void math_expression_representation_functor::append(char*& p, const char * str) const
|
||||
void expression_tree_representation_functor::append(char*& p, const char * str) const
|
||||
{
|
||||
std::size_t n = std::strlen(str);
|
||||
std::memcpy(p, str, n);
|
||||
p+=n;
|
||||
}
|
||||
|
||||
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
|
||||
void expression_tree_representation_functor::operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf_t) const
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||
append(root_node.lhs, detail::is_assignment(root_node.op));
|
||||
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||
|
@@ -28,16 +28,16 @@ base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_si
|
||||
{ }
|
||||
|
||||
|
||||
bool base::requires_fallback(math_expression const & expression)
|
||||
bool base::requires_fallback(expression_tree const & expression)
|
||||
{
|
||||
for(math_expression::node const & node: expression.tree())
|
||||
for(expression_tree::node const & node: expression.tree())
|
||||
if( (node.lhs.subtype==DENSE_ARRAY_TYPE && (node.lhs.array->stride()[0]>1 || node.lhs.array->start()>0))
|
||||
|| (node.rhs.subtype==DENSE_ARRAY_TYPE && (node.rhs.array->stride()[0]>1 || node.rhs.array->start()>0)))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
int_t base::vector_size(math_expression::node const & node)
|
||||
int_t base::vector_size(expression_tree::node const & node)
|
||||
{
|
||||
if (node.op.type==MATRIX_DIAG_TYPE)
|
||||
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
|
||||
@@ -50,7 +50,7 @@ int_t base::vector_size(math_expression::node const & node)
|
||||
|
||||
}
|
||||
|
||||
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
|
||||
std::pair<int_t, int_t> base::matrix_size(expression_tree::container_type const & tree, expression_tree::node const & node)
|
||||
{
|
||||
if (node.op.type==VDIAG_TYPE)
|
||||
{
|
||||
@@ -72,20 +72,20 @@ std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const
|
||||
base::base(binding_policy_t binding_policy) : binding_policy_(binding_policy)
|
||||
{}
|
||||
|
||||
unsigned int base::lmem_usage(math_expression const &) const
|
||||
unsigned int base::lmem_usage(expression_tree const &) const
|
||||
{ return 0; }
|
||||
|
||||
unsigned int base::registers_usage(math_expression const &) const
|
||||
unsigned int base::registers_usage(expression_tree const &) const
|
||||
{ return 0; }
|
||||
|
||||
unsigned int base::temporary_workspace(math_expression const &) const
|
||||
unsigned int base::temporary_workspace(expression_tree const &) const
|
||||
{ return 0; }
|
||||
|
||||
base::~base()
|
||||
{
|
||||
}
|
||||
|
||||
std::string base::generate(std::string const & suffix, math_expression const & expression, driver::Device const & device)
|
||||
std::string base::generate(std::string const & suffix, expression_tree const & expression, driver::Device const & device)
|
||||
{
|
||||
int err = is_invalid(expression, device);
|
||||
if(err != 0)
|
||||
@@ -104,7 +104,7 @@ std::string base::generate(std::string const & suffix, math_expression const &
|
||||
}
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{ return TEMPLATE_VALID; }
|
||||
|
||||
template<class TType, class PType>
|
||||
@@ -124,7 +124,7 @@ std::shared_ptr<base> base_impl<TType, PType>::clone() const
|
||||
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid(math_expression const & expressions, driver::Device const & device) const
|
||||
int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
||||
{
|
||||
//Query device informations
|
||||
size_t lmem_available = device.local_mem_size();
|
||||
|
@@ -26,14 +26,14 @@ elementwise_1d_parameters::elementwise_1d_parameters(unsigned int _simd_width,
|
||||
}
|
||||
|
||||
|
||||
int elementwise_1d::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string elementwise_1d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const
|
||||
std::string elementwise_1d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const
|
||||
{
|
||||
driver::backend_type backend = device.backend();
|
||||
std::string _size_t = size_type(device);
|
||||
@@ -43,7 +43,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
|
||||
std::string dtype = append_width("#scalartype",p_.simd_width);
|
||||
|
||||
|
||||
std::vector<size_t> assigned_scalar = filter_nodes([](math_expression::node const & node) {
|
||||
std::vector<size_t> assigned_scalar = filter_nodes([](expression_tree::node const & node) {
|
||||
return detail::is_assignment(node.op) && node.lhs.subtype==DENSE_ARRAY_TYPE && node.lhs.array->shape().max()==1;
|
||||
}, expressions, expressions.root(), true);
|
||||
|
||||
@@ -74,8 +74,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
|
||||
stream.inc_tab();
|
||||
|
||||
|
||||
math_expression::container_type const & tree = expressions.tree();
|
||||
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
|
||||
expression_tree::container_type const & tree = expressions.tree();
|
||||
std::vector<std::size_t> sfors = filter_nodes([](expression_tree::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
|
||||
|
||||
for(unsigned int i = 0 ; i < sfors.size() ; ++i)
|
||||
{
|
||||
@@ -100,7 +100,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
|
||||
if(sfors.size())
|
||||
root = tree[sfors.back()].lhs.node_index;
|
||||
|
||||
std::vector<std::size_t> assigned = filter_nodes([](math_expression::node const & node){return detail::is_assignment(node.op);}, expressions, root, true);
|
||||
std::vector<std::size_t> assigned = filter_nodes([](expression_tree::node const & node){return detail::is_assignment(node.op);}, expressions, root, true);
|
||||
std::set<std::string> processed;
|
||||
|
||||
//Declares register to store results
|
||||
@@ -182,14 +182,14 @@ elementwise_1d::elementwise_1d(unsigned int simd, unsigned int ls, unsigned int
|
||||
{}
|
||||
|
||||
|
||||
std::vector<int_t> elementwise_1d::input_sizes(math_expression const & expressions) const
|
||||
std::vector<int_t> elementwise_1d::input_sizes(expression_tree const & expressions) const
|
||||
{
|
||||
return {expressions.shape().max()};
|
||||
}
|
||||
|
||||
void elementwise_1d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
|
||||
{
|
||||
math_expression const & expressions = control.x();
|
||||
expression_tree const & expressions = control.x();
|
||||
//Size
|
||||
int_t size = input_sizes(expressions)[0];
|
||||
//Fallback
|
||||
|
@@ -20,7 +20,7 @@ elementwise_2d_parameters::elementwise_2d_parameters(unsigned int _simd_width,
|
||||
|
||||
|
||||
|
||||
int elementwise_2d::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.simd_width>1)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
@@ -29,7 +29,7 @@ int elementwise_2d::is_invalid_impl(driver::Device const &, math_expression cons
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string elementwise_2d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const
|
||||
std::string elementwise_2d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const
|
||||
{
|
||||
kernel_generation_stream stream;
|
||||
std::string _size_t = size_type(device);
|
||||
@@ -114,7 +114,7 @@ elementwise_2d::elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int
|
||||
base_impl<elementwise_2d, elementwise_2d_parameters>(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
|
||||
{}
|
||||
|
||||
std::vector<int_t> elementwise_2d::input_sizes(math_expression const & expression) const
|
||||
std::vector<int_t> elementwise_2d::input_sizes(expression_tree const & expression) const
|
||||
{
|
||||
std::pair<int_t, int_t> size = matrix_size(expression.tree(), lhs_most(expression.tree(), expression.root()));
|
||||
return {size.first, size.second};
|
||||
@@ -122,7 +122,7 @@ std::vector<int_t> elementwise_2d::input_sizes(math_expression const & expressi
|
||||
|
||||
void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program const & program, std::string const & suffix, base &, execution_handler const & control)
|
||||
{
|
||||
math_expression const & expressions = control.x();
|
||||
expression_tree const & expressions = control.x();
|
||||
std::string name = "elementwise_1d";
|
||||
name +=suffix;
|
||||
driver::Kernel kernel(program, name.c_str());
|
||||
|
@@ -27,7 +27,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
}
|
||||
|
||||
|
||||
unsigned int matrix_product::lmem_usage(math_expression const & expression) const
|
||||
unsigned int matrix_product::lmem_usage(expression_tree const & expression) const
|
||||
{
|
||||
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
|
||||
unsigned int N = 0;
|
||||
@@ -36,7 +36,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
unsigned int matrix_product::registers_usage(math_expression const & expression) const
|
||||
unsigned int matrix_product::registers_usage(expression_tree const & expression) const
|
||||
{
|
||||
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
|
||||
|
||||
@@ -44,7 +44,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
unsigned int matrix_product::temporary_workspace(math_expression const & expressions) const
|
||||
unsigned int matrix_product::temporary_workspace(expression_tree const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
@@ -53,7 +53,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return 0;
|
||||
}
|
||||
|
||||
int matrix_product::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int matrix_product::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
// if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
|
||||
// return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
@@ -103,7 +103,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string matrix_product::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const &) const
|
||||
std::string matrix_product::generate_impl(std::string const & suffix, expression_tree const & expression, driver::Device const & device, mapping_type const &) const
|
||||
{
|
||||
using std::string;
|
||||
using tools::to_string;
|
||||
@@ -651,9 +651,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
|
||||
}
|
||||
|
||||
std::vector<int_t> matrix_product::infos(math_expression const & expression, symbolic::preset::matrix_product::args& arguments) const
|
||||
std::vector<int_t> matrix_product::infos(expression_tree const & expression, symbolic::preset::matrix_product::args& arguments) const
|
||||
{
|
||||
math_expression::container_type const & array = expression.tree();
|
||||
expression_tree::container_type const & array = expression.tree();
|
||||
std::size_t root = expression.root();
|
||||
arguments = symbolic::preset::matrix_product::check(array, root);
|
||||
int_t M = arguments.C->array->shape()[0];
|
||||
@@ -671,10 +671,10 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
else throw;
|
||||
}
|
||||
|
||||
std::vector<int_t> matrix_product::input_sizes(math_expression const & expressions) const
|
||||
std::vector<int_t> matrix_product::input_sizes(expression_tree const & expressions) const
|
||||
{
|
||||
symbolic::preset::matrix_product::args dummy;
|
||||
return infos((math_expression&)expressions, dummy);
|
||||
return infos((expression_tree&)expressions, dummy);
|
||||
}
|
||||
|
||||
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, execution_handler const & control)
|
||||
@@ -682,7 +682,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
using namespace tools;
|
||||
|
||||
matrix_product & fallback = (matrix_product&)fallback_base;
|
||||
math_expression const & expressions = control.x();
|
||||
expression_tree const & expressions = control.x();
|
||||
|
||||
|
||||
symbolic::preset::matrix_product::args args;
|
||||
|
@@ -20,20 +20,20 @@ reduce_1d_parameters::reduce_1d_parameters(unsigned int _simd_width,
|
||||
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
||||
{ }
|
||||
|
||||
unsigned int reduce_1d::lmem_usage(math_expression const & x) const
|
||||
unsigned int reduce_1d::lmem_usage(expression_tree const & x) const
|
||||
{
|
||||
numeric_type numeric_t= lhs_most(x.tree(), x.root()).lhs.dtype;
|
||||
return p_.local_size_0*size_of(numeric_t);
|
||||
}
|
||||
|
||||
int reduce_1d::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
unsigned int reduce_1d::temporary_workspace(math_expression const &) const
|
||||
unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
|
||||
{
|
||||
if(p_.num_groups > 1)
|
||||
return p_.num_groups;
|
||||
@@ -65,7 +65,7 @@ inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream,
|
||||
stream << "}" << std::endl;
|
||||
}
|
||||
|
||||
std::string reduce_1d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const
|
||||
std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const
|
||||
{
|
||||
kernel_generation_stream stream;
|
||||
|
||||
@@ -87,7 +87,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, math_expression
|
||||
unsigned int offset = 0;
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
numeric_type dtype = lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype;
|
||||
numeric_type dtype = lhs_most(exprs[k]->expression_tree().tree(), exprs[k]->expression_tree().root()).lhs.dtype;
|
||||
std::string sdtype = to_string(dtype);
|
||||
if (exprs[k]->is_index_reduction())
|
||||
{
|
||||
@@ -298,7 +298,7 @@ reduce_1d::reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||
base_impl<reduce_1d, reduce_1d_parameters>(reduce_1d_parameters(simd,ls,ng,fetch), bind)
|
||||
{}
|
||||
|
||||
std::vector<int_t> reduce_1d::input_sizes(math_expression const & x) const
|
||||
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const
|
||||
{
|
||||
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, x, x.root(), false);
|
||||
int_t N = vector_size(lhs_most(x.tree(), reduce_1ds_idx[0]));
|
||||
@@ -307,7 +307,7 @@ std::vector<int_t> reduce_1d::input_sizes(math_expression const & x) const
|
||||
|
||||
void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
|
||||
{
|
||||
math_expression const & x = control.x();
|
||||
expression_tree const & x = control.x();
|
||||
|
||||
//Preprocessing
|
||||
int_t size = input_sizes(x)[0];
|
||||
@@ -319,7 +319,7 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<math_expression::node const *> reduce_1ds;
|
||||
std::vector<expression_tree::node const *> reduce_1ds;
|
||||
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, x, x.root(), false);
|
||||
for (size_t idx: reduce_1ds_idx)
|
||||
reduce_1ds.push_back(&x.tree()[idx]);
|
||||
|
@@ -22,19 +22,19 @@ reduce_2d_parameters::reduce_2d_parameters(unsigned int _simd_width,
|
||||
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
|
||||
|
||||
|
||||
int reduce_2d::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetch_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
unsigned int reduce_2d::lmem_usage(const math_expression&) const
|
||||
unsigned int reduce_2d::lmem_usage(const expression_tree&) const
|
||||
{
|
||||
return (p_.local_size_0+1)*p_.local_size_1;
|
||||
}
|
||||
|
||||
unsigned int reduce_2d::temporary_workspace(math_expression const & expressions) const
|
||||
unsigned int reduce_2d::temporary_workspace(expression_tree const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MN = input_sizes(expressions);
|
||||
int_t M = MN[0];
|
||||
@@ -43,7 +43,7 @@ unsigned int reduce_2d::temporary_workspace(math_expression const & expressions)
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string reduce_2d::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const
|
||||
std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree const & expression, driver::Device const & device, mapping_type const & mapping) const
|
||||
{
|
||||
using tools::to_string;
|
||||
|
||||
@@ -67,7 +67,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, math_expression
|
||||
unsigned int offset = 0;
|
||||
for (const auto & e : reduce_1ds)
|
||||
{
|
||||
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
|
||||
numeric_type dtype = lhs_most(e->expression_tree().tree(), e->expression_tree().root()).lhs.dtype;
|
||||
std::string sdtype = to_string(dtype);
|
||||
if (e->is_index_reduction())
|
||||
{
|
||||
@@ -360,7 +360,7 @@ reduce_2d::reduce_2d(reduce_2d::parameters_type const & parameters,
|
||||
base_impl<reduce_2d, reduce_2d_parameters>(parameters, binding_policy),
|
||||
reduce_1d_type_(rtype){ }
|
||||
|
||||
std::vector<int_t> reduce_2d::input_sizes(math_expression const & expression) const
|
||||
std::vector<int_t> reduce_2d::input_sizes(expression_tree const & expression) const
|
||||
{
|
||||
std::vector<std::size_t> idx = filter_nodes(&is_reduce_1d, expression, expression.root(), false);
|
||||
std::pair<int_t, int_t> MN = matrix_size(expression.tree(), lhs_most(expression.tree(), idx[0]));
|
||||
@@ -371,10 +371,10 @@ std::vector<int_t> reduce_2d::input_sizes(math_expression const & expression) co
|
||||
|
||||
void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
|
||||
{
|
||||
math_expression const & expression = control.x();
|
||||
expression_tree const & expression = control.x();
|
||||
|
||||
std::vector<int_t> MN = input_sizes(expression);
|
||||
std::vector<math_expression::node const *> reduce_1ds;
|
||||
std::vector<expression_tree::node const *> reduce_1ds;
|
||||
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, expression, expression.root(), false);
|
||||
for (size_t idx : reduce_1ds_idx)
|
||||
reduce_1ds.push_back(&expression.tree()[idx]);
|
||||
|
@@ -12,7 +12,7 @@ namespace templates
|
||||
{
|
||||
|
||||
//Generate
|
||||
inline std::string generate_arguments(std::string const &, driver::Device const & device, mapping_type const & mappings, math_expression const & expressions)
|
||||
inline std::string generate_arguments(std::string const &, driver::Device const & device, mapping_type const & mappings, expression_tree const & expressions)
|
||||
{
|
||||
std::string kwglobal = Global(device.backend()).get();
|
||||
std::string _size_t = size_type(device);
|
||||
@@ -90,9 +90,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
|
||||
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf_t) const
|
||||
{
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
|
||||
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||
@@ -106,7 +106,7 @@ private:
|
||||
driver::Kernel & kernel_;
|
||||
};
|
||||
|
||||
inline void set_arguments(math_expression const & expression, driver::Kernel & kernel, unsigned int & current_arg, binding_policy_t binding_policy)
|
||||
inline void set_arguments(expression_tree const & expression, driver::Kernel & kernel, unsigned int & current_arg, binding_policy_t binding_policy)
|
||||
{
|
||||
std::unique_ptr<symbolic_binder> binder;
|
||||
if (binding_policy==BIND_SEQUENTIAL)
|
||||
|
@@ -12,18 +12,18 @@ namespace templates
|
||||
class map_functor : public traversal_functor
|
||||
{
|
||||
|
||||
numeric_type get_numeric_type(isaac::math_expression const * math_expression, size_t root_idx) const
|
||||
numeric_type get_numeric_type(isaac::expression_tree const * expression_tree, size_t root_idx) const
|
||||
{
|
||||
math_expression::node const * root_node = &math_expression->tree()[root_idx];
|
||||
expression_tree::node const * root_node = &expression_tree->tree()[root_idx];
|
||||
while (root_node->lhs.dtype==INVALID_NUMERIC_TYPE)
|
||||
root_node = &math_expression->tree()[root_node->lhs.node_index];
|
||||
root_node = &expression_tree->tree()[root_node->lhs.node_index];
|
||||
return root_node->lhs.dtype;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
std::shared_ptr<mapped_object> binary_leaf(isaac::math_expression const * math_expression, size_t root_idx, mapping_type const * mapping) const
|
||||
std::shared_ptr<mapped_object> binary_leaf(isaac::expression_tree const * expression_tree, size_t root_idx, mapping_type const * mapping) const
|
||||
{
|
||||
return std::shared_ptr<mapped_object>(new T(to_string(math_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, math_expression, root_idx)));
|
||||
return std::shared_ptr<mapped_object>(new T(to_string(expression_tree->dtype()), binder_.get(), mapped_object::node_info(mapping, expression_tree, root_idx)));
|
||||
}
|
||||
|
||||
std::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const
|
||||
@@ -60,10 +60,10 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
|
||||
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf_t) const
|
||||
{
|
||||
mapping_type::key_type key(root_idx, leaf_t);
|
||||
math_expression::node const & root_node = math_expression.tree()[root_idx];
|
||||
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
|
||||
|
||||
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
|
||||
@@ -72,25 +72,25 @@ public:
|
||||
else if ( leaf_t== PARENT_NODE_TYPE)
|
||||
{
|
||||
if (root_node.op.type==VDIAG_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type==MATRIX_DIAG_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type==MATRIX_ROW_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type==MATRIX_COLUMN_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&expression_tree, root_idx, &mapping_)));
|
||||
else if(root_node.op.type==ACCESS_INDEX_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (detail::is_scalar_reduce_1d(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (detail::is_vector_reduce_1d(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == MATRIX_PRODUCT_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == REPEAT_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OUTER_PROD_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&expression_tree, root_idx, &mapping_)));
|
||||
else if (detail::is_cast(root_node.op))
|
||||
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
|
||||
}
|
||||
|
@@ -55,7 +55,7 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_reduce_1d(math_expression::node const & node)
|
||||
inline bool is_reduce_1d(expression_tree::node const & node)
|
||||
{
|
||||
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY
|
||||
|
@@ -39,7 +39,7 @@ driver::Program const & profiles::value_type::init(execution_handler const & exp
|
||||
|
||||
char* ptr = program_name;
|
||||
bind_independent binder;
|
||||
traverse(expression.x(), expression.x().root(), math_expression_representation_functor(binder, ptr),true);
|
||||
traverse(expression.x(), expression.x().root(), expression_tree_representation_functor(binder, ptr),true);
|
||||
*ptr='\0';
|
||||
pname = std::string(program_name);
|
||||
|
||||
|
@@ -105,12 +105,12 @@ namespace isaac
|
||||
}
|
||||
|
||||
/** @brief Parses the breakpoints for a given expression tree */
|
||||
static void parse(math_expression::container_type&array, size_t idx,
|
||||
static void parse(expression_tree::container_type&array, size_t idx,
|
||||
breakpoints_t & breakpoints,
|
||||
expression_type & final_type,
|
||||
bool is_first = true)
|
||||
{
|
||||
math_expression::node & node = array[idx];
|
||||
expression_tree::node & node = array[idx];
|
||||
|
||||
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
|
||||
//Left
|
||||
@@ -146,14 +146,14 @@ namespace isaac
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Executes a math_expression on the given models map*/
|
||||
/** @brief Executes a expression_tree on the given models map*/
|
||||
void execute(execution_handler const & c, profiles::map_type & profiles)
|
||||
{
|
||||
math_expression expression = c.x();
|
||||
expression_tree expression = c.x();
|
||||
driver::Context const & context = expression.context();
|
||||
size_t rootidx = expression.root();
|
||||
math_expression::container_type & tree = const_cast<math_expression::container_type &>(expression.tree());
|
||||
math_expression::node root_save = tree[rootidx];
|
||||
expression_tree::container_type & tree = const_cast<expression_tree::container_type &>(expression.tree());
|
||||
expression_tree::node root_save = tree[rootidx];
|
||||
|
||||
//Todo: technically the datatype should be per temporary
|
||||
numeric_type dtype = expression.dtype();
|
||||
@@ -186,8 +186,8 @@ namespace isaac
|
||||
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
|
||||
{
|
||||
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(it->first, dtype)];
|
||||
math_expression::node const & node = tree[it->second->node_index];
|
||||
math_expression::node const & lmost = lhs_most(tree, node);
|
||||
expression_tree::node const & node = tree[it->second->node_index];
|
||||
expression_tree::node const & lmost = lhs_most(tree, node);
|
||||
|
||||
//Creates temporary
|
||||
std::shared_ptr<array> tmp;
|
||||
@@ -217,7 +217,7 @@ namespace isaac
|
||||
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
tree[rootidx] = root_save;
|
||||
|
||||
//Incorporates the temporary within, the math_expression
|
||||
//Incorporates the temporary within, the expression_tree
|
||||
fill(*it->second, (array&)*tmp);
|
||||
}
|
||||
}
|
||||
|
@@ -49,7 +49,7 @@ op_element::op_element() {}
|
||||
op_element::op_element(operation_type_family const & _type_family, operation_type const & _type) : type_family(_type_family), type(_type){}
|
||||
|
||||
//
|
||||
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
|
||||
expression_tree::expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
|
||||
: tree_(1), root_(0), context_(NULL), dtype_(INVALID_NUMERIC_TYPE), shape_(1)
|
||||
{
|
||||
fill(tree_[0].lhs, lhs);
|
||||
@@ -57,7 +57,7 @@ math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, con
|
||||
fill(tree_[0].rhs, rhs);
|
||||
}
|
||||
|
||||
math_expression::math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype)
|
||||
expression_tree::expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype)
|
||||
: tree_(1), root_(0), context_(NULL), dtype_(dtype), shape_(1)
|
||||
{
|
||||
fill(tree_[0].lhs, lhs);
|
||||
@@ -65,7 +65,7 @@ math_expression::math_expression(for_idx_t const &lhs, value_scalar const &rhs,
|
||||
fill(tree_[0].rhs, rhs);
|
||||
}
|
||||
|
||||
math_expression::math_expression(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype)
|
||||
expression_tree::expression_tree(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype)
|
||||
: tree_(1), root_(0), context_(NULL), dtype_(dtype), shape_(1)
|
||||
{
|
||||
fill(tree_[0].lhs, lhs);
|
||||
@@ -75,11 +75,11 @@ math_expression::math_expression(value_scalar const &lhs, for_idx_t const &rhs,
|
||||
|
||||
|
||||
|
||||
//math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
|
||||
//math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
//expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
|
||||
//expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
|
||||
|
||||
template<class LT, class RT>
|
||||
math_expression::math_expression(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
expression_tree::expression_tree(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
tree_(1), root_(0), context_(&context), dtype_(dtype), shape_(shape)
|
||||
{
|
||||
fill(tree_[0].lhs, lhs);
|
||||
@@ -88,7 +88,7 @@ math_expression::math_expression(LT const & lhs, RT const & rhs, op_element cons
|
||||
}
|
||||
|
||||
template<class RT>
|
||||
math_expression::math_expression(math_expression const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
expression_tree::expression_tree(expression_tree const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
tree_(lhs.tree_.size() + 1), root_(tree_.size()-1), context_(&context), dtype_(dtype), shape_(shape)
|
||||
{
|
||||
std::copy(lhs.tree_.begin(), lhs.tree_.end(), tree_.begin());
|
||||
@@ -98,7 +98,7 @@ math_expression::math_expression(math_expression const & lhs, RT const & rhs, op
|
||||
}
|
||||
|
||||
template<class LT>
|
||||
math_expression::math_expression(LT const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
expression_tree::expression_tree(LT const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
|
||||
tree_(rhs.tree_.size() + 1), root_(tree_.size() - 1), context_(&context), dtype_(dtype), shape_(shape)
|
||||
{
|
||||
std::copy(rhs.tree_.begin(), rhs.tree_.end(), tree_.begin());
|
||||
@@ -107,7 +107,7 @@ math_expression::math_expression(LT const & lhs, math_expression const & rhs, op
|
||||
fill(tree_[root_].rhs, rhs.root_);
|
||||
}
|
||||
|
||||
math_expression::math_expression(math_expression const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape):
|
||||
expression_tree::expression_tree(expression_tree const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape):
|
||||
tree_(lhs.tree_.size() + rhs.tree_.size() + 1), root_(tree_.size()-1), context_(&context), dtype_(dtype), shape_(shape)
|
||||
{
|
||||
std::size_t lsize = lhs.tree_.size();
|
||||
@@ -123,84 +123,77 @@ math_expression::math_expression(math_expression const & lhs, math_expression co
|
||||
root_ = tree_.size() - 1;
|
||||
}
|
||||
|
||||
template math_expression::math_expression(math_expression const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(math_expression const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(math_expression const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(math_expression const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(expression_tree const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(expression_tree const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(expression_tree const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(expression_tree const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
|
||||
template math_expression::math_expression(value_scalar const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(value_scalar const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(value_scalar const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(value_scalar const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(value_scalar const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(value_scalar const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(value_scalar const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(value_scalar const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
|
||||
template math_expression::math_expression(invalid_node const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(invalid_node const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(invalid_node const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(invalid_node const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(invalid_node const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(invalid_node const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(invalid_node const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(invalid_node const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
|
||||
template math_expression::math_expression(array_base const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(array_base const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(array_base const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(array_base const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(array_base const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(array_base const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(array_base const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(array_base const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(array_base const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(array_base const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
|
||||
template math_expression::math_expression(for_idx_t const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template math_expression::math_expression(for_idx_t const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(for_idx_t const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
template expression_tree::expression_tree(for_idx_t const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
|
||||
|
||||
math_expression::container_type & math_expression::tree()
|
||||
expression_tree::container_type & expression_tree::tree()
|
||||
{ return tree_; }
|
||||
|
||||
math_expression::container_type const & math_expression::tree() const
|
||||
expression_tree::container_type const & expression_tree::tree() const
|
||||
{ return tree_; }
|
||||
|
||||
std::size_t math_expression::root() const
|
||||
std::size_t expression_tree::root() const
|
||||
{ return root_; }
|
||||
|
||||
driver::Context const & math_expression::context() const
|
||||
driver::Context const & expression_tree::context() const
|
||||
{ return *context_; }
|
||||
|
||||
numeric_type const & math_expression::dtype() const
|
||||
numeric_type const & expression_tree::dtype() const
|
||||
{ return dtype_; }
|
||||
|
||||
shape_t math_expression::shape() const
|
||||
shape_t expression_tree::shape() const
|
||||
{ return shape_; }
|
||||
|
||||
int_t math_expression::dim() const
|
||||
int_t expression_tree::dim() const
|
||||
{ return (int_t)shape_.size(); }
|
||||
|
||||
//math_expression& math_expression::reshape(int_t size1, int_t size2)
|
||||
//{
|
||||
// assert(size1*size2==prod(shape_));
|
||||
// shape_ = size4(size1, size2);
|
||||
// return *this;
|
||||
//}
|
||||
expression_tree expression_tree::operator-()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
|
||||
|
||||
math_expression math_expression::operator-()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
|
||||
|
||||
math_expression math_expression::operator!()
|
||||
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
|
||||
expression_tree expression_tree::operator!()
|
||||
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
|
||||
|
||||
//
|
||||
|
||||
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
|
||||
expression_tree::node const & lhs_most(expression_tree::container_type const & array, expression_tree::node const & init)
|
||||
{
|
||||
math_expression::node const * current = &init;
|
||||
expression_tree::node const * current = &init;
|
||||
while (current->lhs.subtype==COMPOSITE_OPERATOR_TYPE)
|
||||
current = &array[current->lhs.node_index];
|
||||
return *current;
|
||||
}
|
||||
|
||||
math_expression::node const & lhs_most(math_expression::container_type const & array, size_t root)
|
||||
expression_tree::node const & lhs_most(expression_tree::container_type const & array, size_t root)
|
||||
{ return lhs_most(array, array[root]); }
|
||||
|
||||
//
|
||||
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
|
||||
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
|
||||
expression_tree for_idx_t::operator=(value_scalar const & r) const { return expression_tree(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
|
||||
expression_tree for_idx_t::operator=(expression_tree const & r) const { return expression_tree(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
|
||||
|
||||
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
|
||||
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
|
||||
math_expression for_idx_t::operator*=(value_scalar const & r) const { return *this = *this * r; }
|
||||
math_expression for_idx_t::operator/=(value_scalar const & r) const { return *this = *this / r; }
|
||||
expression_tree for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
|
||||
expression_tree for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
|
||||
expression_tree for_idx_t::operator*=(value_scalar const & r) const { return *this = *this * r; }
|
||||
expression_tree for_idx_t::operator/=(value_scalar const & r) const { return *this = *this / r; }
|
||||
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ inline std::string to_string(tree_node const & e)
|
||||
return tools::to_string(e.subtype);
|
||||
}
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node)
|
||||
inline std::ostream & operator<<(std::ostream & os, expression_tree::node const & s_node)
|
||||
{
|
||||
os << "LHS: " << to_string(s_node.lhs) << "|" << s_node.lhs.dtype << ", "
|
||||
<< "OP: " << s_node.op.type_family << " | " << s_node.op.type << ", "
|
||||
@@ -42,11 +42,11 @@ inline std::ostream & operator<<(std::ostream & os, math_expression::node const
|
||||
|
||||
namespace detail
|
||||
{
|
||||
/** @brief Recursive worker routine for printing a whole math_expression */
|
||||
inline void print_node(std::ostream & os, isaac::math_expression const & s, size_t node_index, size_t indent = 0)
|
||||
/** @brief Recursive worker routine for printing a whole expression_tree */
|
||||
inline void print_node(std::ostream & os, isaac::expression_tree const & s, size_t node_index, size_t indent = 0)
|
||||
{
|
||||
math_expression::container_type const & nodes = s.tree();
|
||||
math_expression::node const & current_node = nodes[node_index];
|
||||
expression_tree::container_type const & nodes = s.tree();
|
||||
expression_tree::node const & current_node = nodes[node_index];
|
||||
|
||||
for (size_t i=0; i<indent; ++i)
|
||||
os << " ";
|
||||
@@ -61,7 +61,7 @@ namespace detail
|
||||
}
|
||||
}
|
||||
|
||||
std::string to_string(isaac::math_expression const & s)
|
||||
std::string to_string(isaac::expression_tree const & s)
|
||||
{
|
||||
std::ostringstream os;
|
||||
detail::print_node(os, s, s.root());
|
||||
|
@@ -9,7 +9,7 @@ namespace symbolic
|
||||
namespace preset
|
||||
{
|
||||
|
||||
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
|
||||
void matrix_product::handle_node(expression_tree::container_type const & tree, size_t rootidx, args & a)
|
||||
{
|
||||
//Matrix-Matrix product node
|
||||
if(tree[rootidx].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
|
||||
@@ -46,7 +46,7 @@ void matrix_product::handle_node(math_expression::container_type const & tree, s
|
||||
}
|
||||
}
|
||||
|
||||
matrix_product::args matrix_product::check(math_expression::container_type const & tree, size_t rootidx)
|
||||
matrix_product::args matrix_product::check(expression_tree::container_type const & tree, size_t rootidx)
|
||||
{
|
||||
tree_node const * assigned = &tree[rootidx].lhs;
|
||||
numeric_type dtype = assigned->dtype;
|
||||
|
@@ -50,7 +50,7 @@ INSTANTIATE(double)
|
||||
|
||||
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
|
||||
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
|
||||
value_scalar::value_scalar(math_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
|
||||
value_scalar::value_scalar(expression_tree const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
|
||||
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
|
||||
values_holder value_scalar::values() const
|
||||
{ return values_; }
|
||||
|
@@ -18,7 +18,7 @@ extern "C"
|
||||
isaac::driver::backend::release();
|
||||
}
|
||||
|
||||
void execute(sc::math_expression const & operation, sc::driver::Context const & context,
|
||||
void execute(sc::expression_tree const & operation, sc::driver::Context const & context,
|
||||
cl_uint numCommandQueues, cl_command_queue *commandQueues,
|
||||
cl_uint numEventsInWaitList, const cl_event *eventWaitList,
|
||||
cl_event *events)
|
||||
|
@@ -252,15 +252,15 @@ void export_core()
|
||||
#undef INSTANTIATE
|
||||
|
||||
bp::enum_<sc::expression_type>("operations")
|
||||
MAP_ENUM(AXPY_TYPE, sc)
|
||||
MAP_ENUM(GER_TYPE, sc)
|
||||
MAP_ENUM(DOT_TYPE, sc)
|
||||
MAP_ENUM(GEMV_N_TYPE, sc)
|
||||
MAP_ENUM(GEMV_T_TYPE, sc)
|
||||
MAP_ENUM(GEMM_NN_TYPE, sc)
|
||||
MAP_ENUM(GEMM_TN_TYPE, sc)
|
||||
MAP_ENUM(GEMM_NT_TYPE, sc)
|
||||
MAP_ENUM(GEMM_TT_TYPE, sc);
|
||||
MAP_ENUM(ELEMENTWISE_1D, sc)
|
||||
MAP_ENUM(ELEMENTWISE_2D, sc)
|
||||
MAP_ENUM(REDUCE_1D, sc)
|
||||
MAP_ENUM(REDUCE_2D_ROWS, sc)
|
||||
MAP_ENUM(REDUCE_2D_COLS, sc)
|
||||
MAP_ENUM(MATRIX_PRODUCT_NN, sc)
|
||||
MAP_ENUM(MATRIX_PRODUCT_TN, sc)
|
||||
MAP_ENUM(MATRIX_PRODUCT_NT, sc)
|
||||
MAP_ENUM(MATRIX_PRODUCT_TT, sc);
|
||||
|
||||
#define ADD_SCALAR_HANDLING(OP)\
|
||||
.def(bp::self OP int())\
|
||||
@@ -276,7 +276,7 @@ void export_core()
|
||||
.def(bp::self OP bp::self)\
|
||||
ADD_SCALAR_HANDLING(OP)
|
||||
|
||||
bp::class_<sc::math_expression >("math_expression", bp::no_init)
|
||||
bp::class_<sc::expression_tree >("expression_tree", bp::no_init)
|
||||
ADD_ARRAY_OPERATOR(+)
|
||||
ADD_ARRAY_OPERATOR(-)
|
||||
ADD_ARRAY_OPERATOR(*)
|
||||
@@ -287,7 +287,7 @@ void export_core()
|
||||
ADD_ARRAY_OPERATOR(<=)
|
||||
ADD_ARRAY_OPERATOR(==)
|
||||
ADD_ARRAY_OPERATOR(!=)
|
||||
.add_property("context", bp::make_function(&sc::math_expression::context, bp::return_internal_reference<>()))
|
||||
.add_property("context", bp::make_function(&sc::expression_tree::context, bp::return_internal_reference<>()))
|
||||
.def(bp::self_ns::abs(bp::self))
|
||||
// .def(bp::self_ns::pow(bp::self))
|
||||
;
|
||||
@@ -295,8 +295,8 @@ void export_core()
|
||||
|
||||
#define ADD_ARRAY_OPERATOR(OP) \
|
||||
.def(bp::self OP bp::self)\
|
||||
.def(bp::self OP bp::other<sc::math_expression>())\
|
||||
.def(bp::other<sc::math_expression>() OP bp::self) \
|
||||
.def(bp::self OP bp::other<sc::expression_tree>())\
|
||||
.def(bp::other<sc::expression_tree>() OP bp::self) \
|
||||
ADD_SCALAR_HANDLING(OP)
|
||||
|
||||
bp::class_<sc::array_base, boost::noncopyable>("array_base", bp::no_init)
|
||||
@@ -322,7 +322,7 @@ void export_core()
|
||||
bp::class_<sc::array,std::shared_ptr<sc::array>, bp::bases<sc::array_base> >
|
||||
( "array", bp::no_init)
|
||||
.def("__init__", bp::make_constructor(detail::create_array, bp::default_call_policies(), (bp::arg("obj"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")= bp::object())))
|
||||
.def(bp::init<sc::math_expression>())
|
||||
.def(bp::init<sc::expression_tree>())
|
||||
;
|
||||
|
||||
bp::class_<sc::view, bp::bases<sc::array_base> >
|
||||
@@ -338,15 +338,15 @@ void export_core()
|
||||
bp::def("empty", &detail::create_empty_array, (bp::arg("shape"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")=bp::object()));
|
||||
|
||||
//Assign
|
||||
bp::def("assign", static_cast<sc::math_expression (*)(sc::array_base const &, sc::array_base const &)>(&sc::assign));\
|
||||
bp::def("assign", static_cast<sc::math_expression (*)(sc::array_base const &, sc::math_expression const &)>(&sc::assign));\
|
||||
bp::def("assign", static_cast<sc::expression_tree (*)(sc::array_base const &, sc::array_base const &)>(&sc::assign));\
|
||||
bp::def("assign", static_cast<sc::expression_tree (*)(sc::array_base const &, sc::expression_tree const &)>(&sc::assign));\
|
||||
|
||||
//Binary
|
||||
#define MAP_FUNCTION(name) \
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::math_expression const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::math_expression const &)>(&sc::name));
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::expression_tree const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::expression_tree const &)>(&sc::name));
|
||||
|
||||
MAP_FUNCTION(maximum)
|
||||
MAP_FUNCTION(minimum)
|
||||
@@ -356,8 +356,8 @@ void export_core()
|
||||
|
||||
//Unary
|
||||
#define MAP_FUNCTION(name) \
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &)>(&sc::name));
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &)>(&sc::name));
|
||||
|
||||
bp::def("zeros", &detail::create_zeros_array, (bp::arg("shape"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")=bp::object()));
|
||||
|
||||
@@ -382,8 +382,8 @@ void export_core()
|
||||
/*--- Reduction operators----*/
|
||||
//---------------------------------------
|
||||
#define MAP_FUNCTION(name) \
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::int_t)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::int_t)>(&sc::name));
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::int_t)>(&sc::name));\
|
||||
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::int_t)>(&sc::name));
|
||||
|
||||
MAP_FUNCTION(sum)
|
||||
MAP_FUNCTION(max)
|
||||
|
@@ -62,7 +62,7 @@ namespace detail
|
||||
std::shared_ptr<sc::driver::Context> make_context(sc::driver::Device const & dev)
|
||||
{ return std::shared_ptr<sc::driver::Context>(new sc::driver::Context(dev)); }
|
||||
|
||||
bp::object enqueue(sc::math_expression const & expression, unsigned int queue_id, bp::list dependencies, bool tune, int label, std::string const & program_name, bool force_recompile)
|
||||
bp::object enqueue(sc::expression_tree const & expression, unsigned int queue_id, bp::list dependencies, bool tune, int label, std::string const & program_name, bool force_recompile)
|
||||
{
|
||||
std::list<sc::driver::Event> events;
|
||||
std::vector<sc::driver::Event> cdependencies = tools::to_vector<sc::driver::Event>(dependencies);
|
||||
@@ -70,7 +70,7 @@ namespace detail
|
||||
sc::execution_options_type execution_options(queue_id, &events, &cdependencies);
|
||||
sc::dispatcher_options_type dispatcher_options(tune, label);
|
||||
sc::compilation_options_type compilation_options(program_name, force_recompile);
|
||||
sc::math_expression::container_type::value_type root = expression.tree()[expression.root()];
|
||||
sc::expression_tree::container_type::value_type root = expression.tree()[expression.root()];
|
||||
if(sc::detail::is_assignment(root.op))
|
||||
{
|
||||
sc::execute(sc::execution_handler(expression, execution_options, dispatcher_options, compilation_options), isaac::profiles::get(execution_options.queue(expression.context())));
|
||||
|
@@ -13,7 +13,7 @@ namespace tpt = isaac::templates;
|
||||
|
||||
namespace detail
|
||||
{
|
||||
bp::list input_sizes(tpt::base & temp, sc::math_expression const & tree)
|
||||
bp::list input_sizes(tpt::base & temp, sc::expression_tree const & tree)
|
||||
{
|
||||
std::vector<isaac::int_t> tmp = temp.input_sizes(tree);
|
||||
return tools::to_list(tmp.begin(), tmp.end());
|
||||
|
Reference in New Issue
Block a user