Code quality: renamed math_expression -> expression_tree

This commit is contained in:
Philippe Tillet
2015-12-19 02:55:24 -05:00
parent 5a035b91a2
commit d9eb51d04a
37 changed files with 557 additions and 565 deletions

View File

@@ -59,30 +59,30 @@ public:
//Numeric operators
array_base& operator=(array_base const &);
array_base& operator=(math_expression const &);
array_base& operator=(expression_tree const &);
array_base& operator=(execution_handler const &);
template<class T>
array_base & operator=(std::vector<T> const & rhs);
array_base & operator=(value_scalar const & rhs);
math_expression operator-();
math_expression operator!();
expression_tree operator-();
expression_tree operator!();
array_base& operator+=(value_scalar const &);
array_base& operator+=(array_base const &);
array_base& operator+=(math_expression const &);
array_base& operator+=(expression_tree const &);
array_base& operator-=(value_scalar const &);
array_base& operator-=(array_base const &);
array_base& operator-=(math_expression const &);
array_base& operator-=(expression_tree const &);
array_base& operator*=(value_scalar const &);
array_base& operator*=(array_base const &);
array_base& operator*=(math_expression const &);
array_base& operator*=(expression_tree const &);
array_base& operator/=(value_scalar const &);
array_base& operator/=(array_base const &);
array_base& operator/=(math_expression const &);
array_base& operator/=(expression_tree const &);
//Indexing (1D)
math_expression operator[](for_idx_t idx) const;
expression_tree operator[](for_idx_t idx) const;
const scalar operator[](int_t) const;
scalar operator[](int_t);
view operator[](slice const &);
@@ -105,7 +105,7 @@ protected:
driver::Buffer data_;
public:
math_expression T;
expression_tree T;
};
class ISAACAPI array : public array_base
@@ -115,7 +115,7 @@ public:
//Copy Constructor
array(array_base const &);
array(array const &);
array(math_expression const & proxy);
array(expression_tree const & proxy);
using array_base::operator=;
};
@@ -132,7 +132,7 @@ public:
class ISAACAPI scalar : public array_base
{
friend value_scalar::value_scalar(const scalar &);
friend value_scalar::value_scalar(const math_expression &);
friend value_scalar::value_scalar(const expression_tree &);
private:
void inject(values_holder&) const;
template<class T> T cast() const;
@@ -140,7 +140,7 @@ public:
explicit scalar(numeric_type dtype, const driver::Buffer &data, int_t offset);
explicit scalar(value_scalar value, driver::Context const & context = driver::backend::contexts::get_default());
explicit scalar(numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
scalar(math_expression const & proxy);
scalar(expression_tree const & proxy);
scalar& operator=(value_scalar const &);
// scalar& operator=(scalar const & s);
using array_base::operator =;
@@ -178,24 +178,24 @@ template<class T> ISAACAPI void copy(array_base const & gA, std::vector<T> & cA,
//Binary operators
#define ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(OPNAME) \
ISAACAPI math_expression OPNAME (array_base const & x, math_expression const & y);\
ISAACAPI math_expression OPNAME (array_base const & x, value_scalar const & y);\
ISAACAPI math_expression OPNAME (array_base const & x, for_idx_t const & y);\
ISAACAPI math_expression OPNAME (array_base const & x, array_base const & y);\
ISAACAPI expression_tree OPNAME (array_base const & x, expression_tree const & y);\
ISAACAPI expression_tree OPNAME (array_base const & x, value_scalar const & y);\
ISAACAPI expression_tree OPNAME (array_base const & x, for_idx_t const & y);\
ISAACAPI expression_tree OPNAME (array_base const & x, array_base const & y);\
\
ISAACAPI math_expression OPNAME (math_expression const & x, math_expression const & y);\
ISAACAPI math_expression OPNAME (math_expression const & x, value_scalar const & y);\
ISAACAPI math_expression OPNAME (math_expression const & x, for_idx_t const & y);\
ISAACAPI math_expression OPNAME (math_expression const & x, array_base const & y);\
ISAACAPI expression_tree OPNAME (expression_tree const & x, expression_tree const & y);\
ISAACAPI expression_tree OPNAME (expression_tree const & x, value_scalar const & y);\
ISAACAPI expression_tree OPNAME (expression_tree const & x, for_idx_t const & y);\
ISAACAPI expression_tree OPNAME (expression_tree const & x, array_base const & y);\
\
ISAACAPI math_expression OPNAME (value_scalar const & y, math_expression const & x);\
ISAACAPI math_expression OPNAME (value_scalar const & y, for_idx_t const & x);\
ISAACAPI math_expression OPNAME (value_scalar const & y, array_base const & x);\
ISAACAPI expression_tree OPNAME (value_scalar const & y, expression_tree const & x);\
ISAACAPI expression_tree OPNAME (value_scalar const & y, for_idx_t const & x);\
ISAACAPI expression_tree OPNAME (value_scalar const & y, array_base const & x);\
\
ISAACAPI math_expression OPNAME (for_idx_t const & y, math_expression const & x);\
ISAACAPI math_expression OPNAME (for_idx_t const & y, for_idx_t const & x);\
ISAACAPI math_expression OPNAME (for_idx_t const & y, value_scalar const & x);\
ISAACAPI math_expression OPNAME (for_idx_t const & y, array_base const & x);
ISAACAPI expression_tree OPNAME (for_idx_t const & y, expression_tree const & x);\
ISAACAPI expression_tree OPNAME (for_idx_t const & y, for_idx_t const & x);\
ISAACAPI expression_tree OPNAME (for_idx_t const & y, value_scalar const & x);\
ISAACAPI expression_tree OPNAME (for_idx_t const & y, array_base const & x);
ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(operator +)
ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(operator -)
@@ -221,29 +221,29 @@ ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR(assign)
#undef ISAAC_DECLARE_ELEMENT_BINARY_OPERATOR
#define ISAAC_DECLARE_ROT(LTYPE, RTYPE, CTYPE, STYPE) \
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s);
expression_tree rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s);
ISAAC_DECLARE_ROT(array_base, array_base, scalar, scalar)
ISAAC_DECLARE_ROT(math_expression, array_base, scalar, scalar)
ISAAC_DECLARE_ROT(array_base, math_expression, scalar, scalar)
ISAAC_DECLARE_ROT(math_expression, math_expression, scalar, scalar)
ISAAC_DECLARE_ROT(expression_tree, array_base, scalar, scalar)
ISAAC_DECLARE_ROT(array_base, expression_tree, scalar, scalar)
ISAAC_DECLARE_ROT(expression_tree, expression_tree, scalar, scalar)
ISAAC_DECLARE_ROT(array_base, array_base, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(math_expression, array_base, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(array_base, math_expression, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(math_expression, math_expression, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(expression_tree, array_base, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(array_base, expression_tree, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(expression_tree, expression_tree, value_scalar, value_scalar)
ISAAC_DECLARE_ROT(array_base, array_base, math_expression, math_expression)
ISAAC_DECLARE_ROT(math_expression, array_base, math_expression, math_expression)
ISAAC_DECLARE_ROT(array_base, math_expression, math_expression, math_expression)
ISAAC_DECLARE_ROT(math_expression, math_expression, math_expression, math_expression)
ISAAC_DECLARE_ROT(array_base, array_base, expression_tree, expression_tree)
ISAAC_DECLARE_ROT(expression_tree, array_base, expression_tree, expression_tree)
ISAAC_DECLARE_ROT(array_base, expression_tree, expression_tree, expression_tree)
ISAAC_DECLARE_ROT(expression_tree, expression_tree, expression_tree, expression_tree)
//--------------------------------
//Unary operators
#define ISAAC_DECLARE_UNARY_OPERATOR(OPNAME) \
ISAACAPI math_expression OPNAME (array_base const & x);\
ISAACAPI math_expression OPNAME (math_expression const & x);
ISAACAPI expression_tree OPNAME (array_base const & x);\
ISAACAPI expression_tree OPNAME (expression_tree const & x);
ISAAC_DECLARE_UNARY_OPERATOR(abs)
ISAAC_DECLARE_UNARY_OPERATOR(acos)
@@ -263,21 +263,21 @@ ISAAC_DECLARE_UNARY_OPERATOR(tan)
ISAAC_DECLARE_UNARY_OPERATOR(tanh)
ISAAC_DECLARE_UNARY_OPERATOR(trans)
ISAACAPI math_expression cast(array_base const &, numeric_type dtype);
ISAACAPI math_expression cast(math_expression const &, numeric_type dtype);
ISAACAPI expression_tree cast(array_base const &, numeric_type dtype);
ISAACAPI expression_tree cast(expression_tree const &, numeric_type dtype);
ISAACAPI math_expression norm(array_base const &, unsigned int order = 2);
ISAACAPI math_expression norm(math_expression const &, unsigned int order = 2);
ISAACAPI expression_tree norm(array_base const &, unsigned int order = 2);
ISAACAPI expression_tree norm(expression_tree const &, unsigned int order = 2);
#undef ISAAC_DECLARE_UNARY_OPERATOR
ISAACAPI math_expression repmat(array_base const &, int_t const & rep1, int_t const & rep2);
ISAACAPI expression_tree repmat(array_base const &, int_t const & rep1, int_t const & rep2);
//Matrix reduction
#define ISAAC_DECLARE_DOT(OPNAME) \
ISAACAPI math_expression OPNAME(array_base const & M, int_t axis = -1);\
ISAACAPI math_expression OPNAME(math_expression const & M, int_t axis = -1);
ISAACAPI expression_tree OPNAME(array_base const & M, int_t axis = -1);\
ISAACAPI expression_tree OPNAME(expression_tree const & M, int_t axis = -1);
ISAAC_DECLARE_DOT(sum)
ISAAC_DECLARE_DOT(argmax)
@@ -286,10 +286,10 @@ ISAAC_DECLARE_DOT((min))
ISAAC_DECLARE_DOT(argmin)
//Fusion
ISAACAPI math_expression fuse(math_expression const & x, math_expression const & y);
ISAACAPI expression_tree fuse(expression_tree const & x, expression_tree const & y);
//For
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & expression);
ISAACAPI expression_tree sfor(expression_tree const & start, expression_tree const & end, expression_tree const & inc, expression_tree const & expression);
static const for_idx_t _i0{0};
static const for_idx_t _i1{1};
static const for_idx_t _i2{2};
@@ -302,41 +302,41 @@ static const for_idx_t _i8{8};
static const for_idx_t _i9{9};
//Initializers
ISAACAPI math_expression eye(int_t, int_t, isaac::numeric_type, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI expression_tree eye(int_t, int_t, isaac::numeric_type, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI expression_tree zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
//Swap
ISAACAPI void swap(view x, view y);
//Reshape
ISAACAPI math_expression reshape(array_base const &, shape_t const &);
ISAACAPI math_expression ravel(array_base const &);
ISAACAPI expression_tree reshape(array_base const &, shape_t const &);
ISAACAPI expression_tree ravel(array_base const &);
//diag
array diag(array_base & x, int offset = 0);
//Row
ISAACAPI math_expression row(array_base const &, value_scalar const &);
ISAACAPI math_expression row(array_base const &, for_idx_t const &);
ISAACAPI math_expression row(array_base const &, math_expression const &);
ISAACAPI expression_tree row(array_base const &, value_scalar const &);
ISAACAPI expression_tree row(array_base const &, for_idx_t const &);
ISAACAPI expression_tree row(array_base const &, expression_tree const &);
ISAACAPI math_expression row(math_expression const &, value_scalar const &);
ISAACAPI math_expression row(math_expression const &, for_idx_t const &);
ISAACAPI math_expression row(math_expression const &, math_expression const &);
ISAACAPI expression_tree row(expression_tree const &, value_scalar const &);
ISAACAPI expression_tree row(expression_tree const &, for_idx_t const &);
ISAACAPI expression_tree row(expression_tree const &, expression_tree const &);
//col
ISAACAPI math_expression col(array_base const &, value_scalar const &);
ISAACAPI math_expression col(array_base const &, for_idx_t const &);
ISAACAPI math_expression col(array_base const &, math_expression const &);
ISAACAPI expression_tree col(array_base const &, value_scalar const &);
ISAACAPI expression_tree col(array_base const &, for_idx_t const &);
ISAACAPI expression_tree col(array_base const &, expression_tree const &);
ISAACAPI math_expression col(math_expression const &, value_scalar const &);
ISAACAPI math_expression col(math_expression const &, for_idx_t const &);
ISAACAPI math_expression col(math_expression const &, math_expression const &);
ISAACAPI expression_tree col(expression_tree const &, value_scalar const &);
ISAACAPI expression_tree col(expression_tree const &, for_idx_t const &);
ISAACAPI expression_tree col(expression_tree const &, expression_tree const &);
//
ISAACAPI std::ostream& operator<<(std::ostream &, array_base const &);
ISAACAPI std::ostream& operator<<(std::ostream &, math_expression const &);
ISAACAPI std::ostream& operator<<(std::ostream &, expression_tree const &);
}
#endif

View File

@@ -27,7 +27,7 @@ typedef std::map<mapping_key, std::shared_ptr<mapped_object> > mapping_type;
/** @brief Mapped Object
*
* This object populates the symbolic mapping associated with a math_expression. (root_id, LHS|RHS|PARENT) => mapped_object
* This object populates the symbolic mapping associated with a expression_tree. (root_id, LHS|RHS|PARENT) => mapped_object
* The tree can then be reconstructed in its symbolic form
*/
class mapped_object
@@ -49,9 +49,9 @@ protected:
public:
struct node_info
{
node_info(mapping_type const * _mapping, math_expression const * _math_expression, size_t _root_idx);
node_info(mapping_type const * _mapping, expression_tree const * _expression_tree, size_t _root_idx);
mapping_type const * mapping;
isaac::math_expression const * math_expression;
isaac::expression_tree const * expression_tree;
size_t root_idx;
};
@@ -105,8 +105,8 @@ public:
mapped_reduce(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
size_t root_idx() const;
isaac::math_expression const & math_expression() const;
math_expression::node root_node() const;
isaac::expression_tree const & expression_tree() const;
expression_tree::node root_node() const;
bool is_index_reduction() const;
op_element root_op() const;
};
@@ -253,7 +253,7 @@ public:
mapped_cast(operation_type type, unsigned int id);
};
extern mapped_object& get(math_expression::container_type const &, size_t, mapping_type const &, size_t);
extern mapped_object& get(expression_tree::container_type const &, size_t, mapping_type const &, size_t);
}
#endif

View File

@@ -13,8 +13,8 @@ namespace detail
{
bool is_node_leaf(op_element const & op);
bool is_scalar_reduce_1d(math_expression::node const & node);
bool is_vector_reduce_1d(math_expression::node const & node);
bool is_scalar_reduce_1d(expression_tree::node const & node);
bool is_vector_reduce_1d(expression_tree::node const & node);
bool is_assignment(op_element const & op);
bool is_elementwise_operator(op_element const & op);
bool is_elementwise_function(op_element const & op);
@@ -24,58 +24,58 @@ namespace detail
class scalar;
/** @brief base functor class for traversing a math_expression */
/** @brief base functor class for traversing a expression_tree */
class traversal_functor
{
public:
void call_before_expansion(math_expression const &, std::size_t) const { }
void call_after_expansion(math_expression const &, std::size_t) const { }
void call_before_expansion(expression_tree const &, std::size_t) const { }
void call_after_expansion(expression_tree const &, std::size_t) const { }
};
/** @brief Recursively execute a functor on a math_expression */
/** @brief Recursively execute a functor on a expression_tree */
template<class Fun>
inline void traverse(isaac::math_expression const & math_expression, std::size_t root_idx, Fun const & fun, bool inspect)
inline void traverse(isaac::expression_tree const & expression_tree, std::size_t root_idx, Fun const & fun, bool inspect)
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
bool recurse = detail::is_node_leaf(root_node.op)?inspect:true;
bool bypass = detail::bypass(root_node.op);
if(!bypass)
fun.call_before_expansion(math_expression, root_idx);
fun.call_before_expansion(expression_tree, root_idx);
//Lhs:
if (recurse)
{
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, fun, inspect);
traverse(expression_tree, root_node.lhs.node_index, fun, inspect);
if (root_node.lhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, LHS_NODE_TYPE);
fun(expression_tree, root_idx, LHS_NODE_TYPE);
}
//Self:
if(!bypass)
fun(math_expression, root_idx, PARENT_NODE_TYPE);
fun(expression_tree, root_idx, PARENT_NODE_TYPE);
//Rhs:
if (recurse && root_node.rhs.subtype!=INVALID_SUBTYPE)
{
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, fun, inspect);
traverse(expression_tree, root_node.rhs.node_index, fun, inspect);
if (root_node.rhs.subtype != INVALID_SUBTYPE)
fun(math_expression, root_idx, RHS_NODE_TYPE);
fun(expression_tree, root_idx, RHS_NODE_TYPE);
}
if(!bypass)
fun.call_after_expansion(math_expression, root_idx);
fun.call_after_expansion(expression_tree, root_idx);
}
class filter_fun : public traversal_functor
{
public:
typedef bool (*pred_t)(math_expression::node const & node);
typedef bool (*pred_t)(expression_tree::node const & node);
filter_fun(pred_t pred, std::vector<size_t> & out);
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const;
private:
pred_t pred_;
std::vector<size_t> & out_;
@@ -85,22 +85,22 @@ class filter_elements_fun : public traversal_functor
{
public:
filter_elements_fun(node_type subtype, std::vector<tree_node> & out);
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const;
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const;
private:
node_type subtype_;
std::vector<tree_node> & out_;
};
std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node),
isaac::math_expression const & math_expression,
std::vector<size_t> filter_nodes(bool (*pred)(expression_tree::node const & node),
isaac::expression_tree const & expression_tree,
size_t root,
bool inspect);
std::vector<tree_node> filter_elements(node_type subtype,
isaac::math_expression const & math_expression);
isaac::expression_tree const & expression_tree);
const char * evaluate(operation_type type);
/** @brief functor for generating the expression string from a math_expression */
/** @brief functor for generating the expression string from a expression_tree */
class evaluate_expression_traversal: public traversal_functor
{
private:
@@ -110,24 +110,24 @@ private:
public:
evaluate_expression_traversal(std::map<std::string, std::string> const & accessors, std::string & str, mapping_type const & mapping);
void call_before_expansion(isaac::math_expression const & math_expression, std::size_t root_idx) const;
void call_after_expansion(math_expression const & /*math_expression*/, std::size_t /*root_idx*/) const;
void operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const;
void call_before_expansion(isaac::expression_tree const & expression_tree, std::size_t root_idx) const;
void call_after_expansion(expression_tree const & /*expression_tree*/, std::size_t /*root_idx*/) const;
void operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const;
};
std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & accessors,
isaac::math_expression const & math_expression, std::size_t root_idx, mapping_type const & mapping);
isaac::expression_tree const & expression_tree, std::size_t root_idx, mapping_type const & mapping);
void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
math_expression const & expressions, mapping_type const & mappings);
expression_tree const & expressions, mapping_type const & mappings);
/** @brief functor for fetching or writing-back the elements in a math_expression */
/** @brief functor for fetching or writing-back the elements in a expression_tree */
class process_traversal : public traversal_functor
{
public:
process_traversal(std::map<std::string, std::string> const & accessors, kernel_generation_stream & stream,
mapping_type const & mapping, std::set<std::string> & already_processed);
void operator()(math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const;
void operator()(expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const;
private:
std::map<std::string, std::string> accessors_;
kernel_generation_stream & stream_;
@@ -136,21 +136,21 @@ private:
};
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
isaac::math_expression const & math_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed);
isaac::expression_tree const & expression_tree, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed);
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
math_expression const & expressions, mapping_type const & mappings);
expression_tree const & expressions, mapping_type const & mappings);
class math_expression_representation_functor : public traversal_functor{
class expression_tree_representation_functor : public traversal_functor{
private:
static void append_id(char * & ptr, unsigned int val);
void append(driver::Buffer const & h, numeric_type dtype, char prefix) const;
void append(tree_node const & lhs_rhs, bool is_assigned) const;
public:
math_expression_representation_functor(symbolic_binder & binder, char *& ptr);
expression_tree_representation_functor(symbolic_binder & binder, char *& ptr);
void append(char*& p, const char * str) const;
void operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const;
void operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf_t) const;
private:
symbolic_binder & binder_;
char *& ptr_;

View File

@@ -60,20 +60,20 @@ public:
unsigned int num_kernels;
};
protected:
static int_t vector_size(math_expression::node const & node);
static std::pair<int_t, int_t> matrix_size(math_expression::container_type const & tree, math_expression::node const & node);
static bool requires_fallback(math_expression const & expressions);
static int_t vector_size(expression_tree::node const & node);
static std::pair<int_t, int_t> matrix_size(expression_tree::container_type const & tree, expression_tree::node const & node);
static bool requires_fallback(expression_tree const & expressions);
private:
virtual std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const = 0;
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const = 0;
public:
base(binding_policy_t binding_policy);
virtual unsigned int temporary_workspace(math_expression const &) const;
virtual unsigned int lmem_usage(math_expression const &) const;
virtual unsigned int registers_usage(math_expression const &) const;
virtual std::vector<int_t> input_sizes(math_expression const & expressions) const = 0;
virtual unsigned int temporary_workspace(expression_tree const &) const;
virtual unsigned int lmem_usage(expression_tree const &) const;
virtual unsigned int registers_usage(expression_tree const &) const;
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
virtual ~base();
std::string generate(std::string const & suffix, math_expression const & expressions, driver::Device const & device);
virtual int is_invalid(math_expression const & expressions, driver::Device const & device) const = 0;
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & expressions) = 0;
virtual std::shared_ptr<base> clone() const = 0;
private:
@@ -85,7 +85,7 @@ template<class TemplateType, class ParametersType>
class base_impl : public base
{
private:
virtual int is_invalid_impl(driver::Device const &, math_expression const &) const;
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
public:
typedef ParametersType parameters_type;
base_impl(parameters_type const & parameters, binding_policy_t binding_policy);
@@ -93,7 +93,7 @@ public:
unsigned int local_size_1() const;
std::shared_ptr<base> clone() const;
/** @brief returns whether or not the profile has undefined behavior on particular device */
int is_invalid(math_expression const & expressions, driver::Device const & device) const;
int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
protected:
parameters_type p_;
binding_policy_t binding_policy_;

View File

@@ -19,12 +19,12 @@ public:
class elementwise_1d : public base_impl<elementwise_1d, elementwise_1d_parameters>
{
private:
virtual int is_invalid_impl(driver::Device const &, math_expression const &) const;
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const;
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const;
public:
elementwise_1d(elementwise_1d::parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
elementwise_1d(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_INDEPENDENT);
std::vector<int_t> input_sizes(math_expression const & expressions) const;
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
};

View File

@@ -22,12 +22,12 @@ public:
class elementwise_2d : public base_impl<elementwise_2d, elementwise_2d_parameters>
{
private:
int is_invalid_impl(driver::Device const &, math_expression const &) const;
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const;
public:
elementwise_2d(parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_INDEPENDENT);
std::vector<int_t> input_sizes(math_expression const & expressions) const;
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
};

View File

@@ -41,17 +41,17 @@ struct matrix_product_parameters : public base::parameters_type
class matrix_product : public base_impl<matrix_product, matrix_product_parameters>
{
private:
unsigned int temporary_workspace(math_expression const & expressions) const;
unsigned int lmem_usage(math_expression const & expressions) const;
unsigned int registers_usage(math_expression const & expressions) const;
int is_invalid_impl(driver::Device const &, math_expression const &) const;
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const &) const;
unsigned int temporary_workspace(expression_tree const & expressions) const;
unsigned int lmem_usage(expression_tree const & expressions) const;
unsigned int registers_usage(expression_tree const & expressions) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const &) const;
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array_base const & A, array_base const & B, array_base const & C,
value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, execution_options_type const & options);
std::vector<int_t> infos(math_expression const & expressions, isaac::symbolic::preset::matrix_product::args &arguments) const;
std::vector<int_t> infos(expression_tree const & expressions, isaac::symbolic::preset::matrix_product::args &arguments) const;
public:
matrix_product(matrix_product::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
std::vector<int_t> input_sizes(math_expression const & expressions) const;
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &ctr);
private:
const char A_trans_;

View File

@@ -20,17 +20,17 @@ struct reduce_1d_parameters : public base::parameters_type
class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
{
private:
unsigned int lmem_usage(math_expression const & expressions) const;
int is_invalid_impl(driver::Device const &, math_expression const &) const;
unsigned int temporary_workspace(math_expression const & expressions) const;
unsigned int lmem_usage(expression_tree const & expressions) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
unsigned int temporary_workspace(expression_tree const & expressions) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const;
public:
reduce_1d(reduce_1d::parameters_type const & parameters, binding_policy_t binding_policy = BIND_INDEPENDENT);
reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_INDEPENDENT);
std::vector<int_t> input_sizes(math_expression const & expressions) const;
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
private:
std::vector< driver::Buffer > tmp_;

View File

@@ -31,12 +31,12 @@ protected:
};
reduce_2d(reduce_2d::parameters_type const & , reduce_1d_type, binding_policy_t);
private:
int is_invalid_impl(driver::Device const &, math_expression const &) const;
unsigned int lmem_usage(math_expression const &) const;
unsigned int temporary_workspace(math_expression const & expressions) const;
std::string generate_impl(std::string const & suffix, math_expression const &, driver::Device const & device, mapping_type const &) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
unsigned int lmem_usage(expression_tree const &) const;
unsigned int temporary_workspace(expression_tree const & expressions) const;
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, mapping_type const &) const;
public:
virtual std::vector<int_t> input_sizes(math_expression const & expressions) const;
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const &);
private:
reduce_1d_type reduce_1d_type_;

View File

@@ -7,10 +7,10 @@
namespace isaac
{
/** @brief Executes a math_expression on the given queue for the given models map*/
/** @brief Executes a expression_tree on the given queue for the given models map*/
void execute(execution_handler const & , profiles::map_type &);
/** @brief Executes a math_expression on the default models map*/
/** @brief Executes a expression_tree on the default models map*/
void execute(execution_handler const &);
}

View File

@@ -142,13 +142,13 @@ struct op_element
struct for_idx_t
{
math_expression operator=(value_scalar const & ) const;
math_expression operator=(math_expression const & ) const;
expression_tree operator=(value_scalar const & ) const;
expression_tree operator=(expression_tree const & ) const;
math_expression operator+=(value_scalar const & ) const;
math_expression operator-=(value_scalar const & ) const;
math_expression operator*=(value_scalar const & ) const;
math_expression operator/=(value_scalar const & ) const;
expression_tree operator+=(value_scalar const & ) const;
expression_tree operator-=(value_scalar const & ) const;
expression_tree operator*=(value_scalar const & ) const;
expression_tree operator/=(value_scalar const & ) const;
int level;
};
@@ -171,7 +171,7 @@ struct tree_node
{
std::size_t node_index;
values_holder vscalar;
isaac::array_base* array;
array_base* array;
for_idx_t for_idx;
};
};
@@ -180,11 +180,11 @@ struct invalid_node{};
void fill(tree_node &x, for_idx_t index);
void fill(tree_node &x, invalid_node);
void fill(tree_node & x, std::size_t node_index);
void fill(tree_node & x, size_t node_index);
void fill(tree_node & x, array_base const & a);
void fill(tree_node & x, value_scalar const & v);
class math_expression
class expression_tree
{
public:
struct node
@@ -197,20 +197,19 @@ public:
typedef std::vector<node> container_type;
public:
math_expression(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype);
math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
expression_tree(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype);
expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
template<class LT, class RT>
math_expression(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
expression_tree(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
template<class RT>
math_expression(math_expression const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
expression_tree(expression_tree const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
template<class LT>
math_expression(LT const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
math_expression(math_expression const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
expression_tree(LT const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
expression_tree(expression_tree const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape);
shape_t shape() const;
math_expression& reshape(int_t size1, int_t size2=1);
int_t dim() const;
container_type & tree();
container_type const & tree() const;
@@ -218,8 +217,8 @@ public:
driver::Context const & context() const;
numeric_type const & dtype() const;
math_expression operator-();
math_expression operator!();
expression_tree operator-();
expression_tree operator!();
private:
container_type tree_;
std::size_t root_;
@@ -279,24 +278,24 @@ struct compilation_options_type
class execution_handler
{
public:
execution_handler(math_expression const & x, execution_options_type const& execution_options = execution_options_type(),
execution_handler(expression_tree const & x, execution_options_type const& execution_options = execution_options_type(),
dispatcher_options_type const & dispatcher_options = dispatcher_options_type(),
compilation_options_type const & compilation_options = compilation_options_type())
: x_(x), execution_options_(execution_options), dispatcher_options_(dispatcher_options), compilation_options_(compilation_options){}
execution_handler(math_expression const & x, execution_handler const & other) : x_(x), execution_options_(other.execution_options_), dispatcher_options_(other.dispatcher_options_), compilation_options_(other.compilation_options_){}
math_expression const & x() const { return x_; }
execution_handler(expression_tree const & x, execution_handler const & other) : x_(x), execution_options_(other.execution_options_), dispatcher_options_(other.dispatcher_options_), compilation_options_(other.compilation_options_){}
expression_tree const & x() const { return x_; }
execution_options_type const & execution_options() const { return execution_options_; }
dispatcher_options_type const & dispatcher_options() const { return dispatcher_options_; }
compilation_options_type const & compilation_options() const { return compilation_options_; }
private:
math_expression x_;
expression_tree x_;
execution_options_type execution_options_;
dispatcher_options_type dispatcher_options_;
compilation_options_type compilation_options_;
};
math_expression::node const & lhs_most(math_expression::container_type const & array_base, math_expression::node const & init);
math_expression::node const & lhs_most(math_expression::container_type const & array_base, size_t root);
expression_tree::node const & lhs_most(expression_tree::container_type const & array_base, expression_tree::node const & init);
expression_tree::node const & lhs_most(expression_tree::container_type const & array_base, size_t root);
}

View File

@@ -9,8 +9,8 @@ namespace isaac
std::string to_string(node_type const & f);
std::string to_string(tree_node const & e);
std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node);
std::string to_string(isaac::math_expression const & s);
std::ostream & operator<<(std::ostream & os, expression_tree::node const & s_node);
std::string to_string(isaac::expression_tree const & s);
}

View File

@@ -34,10 +34,10 @@ public:
};
private:
static void handle_node( math_expression::container_type const &tree, size_t rootidx, args & a);
static void handle_node( expression_tree::container_type const &tree, size_t rootidx, args & a);
public:
static args check(math_expression::container_type const &tree, size_t rootidx);
static args check(expression_tree::container_type const &tree, size_t rootidx);
};
}

View File

@@ -18,13 +18,13 @@ ISAACAPI typename std::conditional<std::is_arithmetic<T>::value, value_scalar, T
{ return wrap_generic(x); }
template<typename T, typename... Args>
ISAACAPI math_expression make_tuple(driver::Context const & context, T const & x, Args... args)
{ return math_expression(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
ISAACAPI expression_tree make_tuple(driver::Context const & context, T const & x, Args... args)
{ return expression_tree(wrap_generic(x), make_tuple(context, args...), op_element(BINARY_TYPE_FAMILY, PAIR_TYPE), context, numeric_type_of(x), {1}); }
inline value_scalar tuple_get(math_expression::container_type const & tree, size_t root, size_t idx)
inline value_scalar tuple_get(expression_tree::container_type const & tree, size_t root, size_t idx)
{
for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root];
expression_tree::node node = tree[root];
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index;
else

View File

@@ -9,7 +9,7 @@ namespace isaac
{
class scalar;
class math_expression;
class expression_tree;
union ISAACAPI values_holder
{
@@ -46,7 +46,7 @@ public:
#undef ISAAC_INSTANTIATE
value_scalar(values_holder values, numeric_type dtype);
explicit value_scalar(scalar const &);
explicit value_scalar(math_expression const &);
explicit value_scalar(expression_tree const &);
explicit value_scalar(numeric_type dtype = INVALID_NUMERIC_TYPE);
values_holder values() const;

View File

@@ -169,7 +169,7 @@ array_base & array_base::operator=(array_base const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -178,7 +178,7 @@ array_base & array_base::operator=(value_scalar const & rhs)
{
if(shape_.min()==0) return *this;
assert(dtype_ == rhs.dtype());
math_expression expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
expression_tree expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression));
return *this;
}
@@ -188,12 +188,12 @@ array_base& array_base::operator=(execution_handler const & c)
{
if(shape_.min()==0) return *this;
assert(dtype_ == c.x().dtype());
math_expression expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
expression_tree expression(*this, c.x(), op_element(BINARY_TYPE_FAMILY, ASSIGN_TYPE), context_, dtype_, shape_);
execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
return *this;
}
array_base & array_base::operator=(math_expression const & rhs)
array_base & array_base::operator=(expression_tree const & rhs)
{
return *this = execution_handler(rhs);
}
@@ -227,54 +227,54 @@ INSTANTIATE(double);
math_expression array_base::operator-()
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
expression_tree array_base::operator-()
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
math_expression array_base::operator!()
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
expression_tree array_base::operator!()
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), context_, INT_TYPE, shape_); }
//
array_base & array_base::operator+=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator+=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
array_base & array_base::operator+=(expression_tree const & rhs)
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, ADD_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator-=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator-=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
array_base & array_base::operator-=(expression_tree const & rhs)
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, SUB_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator*=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator*=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
array_base & array_base::operator*=(expression_tree const & rhs)
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, MULT_TYPE), rhs.context(), dtype_, shape_); }
//----
array_base & array_base::operator/=(value_scalar const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(array_base const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), context_, dtype_, shape_); }
array_base & array_base::operator/=(math_expression const & rhs)
{ return *this = math_expression(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
array_base & array_base::operator/=(expression_tree const & rhs)
{ return *this = expression_tree(*this, rhs, op_element(BINARY_TYPE_FAMILY, DIV_TYPE), rhs.context(), dtype_, shape_); }
/*--- Indexing operators -----*/
//---------------------------------------
math_expression array_base::operator[](for_idx_t idx) const
expression_tree array_base::operator[](for_idx_t idx) const
{
return math_expression(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
return expression_tree(*this, idx, op_element(BINARY_TYPE_FAMILY, ACCESS_INDEX_TYPE), context_, dtype_, {1});
}
scalar array_base::operator [](int_t idx)
@@ -324,7 +324,7 @@ view array_base::operator()(slice const & si, slice const & sj)
//---------------------------------------
/*--- array ---*/
array::array(math_expression const & proxy) : array_base(execution_handler(proxy)) {}
array::array(expression_tree const & proxy) : array_base(execution_handler(proxy)) {}
array::array(array_base const & other): array_base(other.dtype(), other.shape(), other.context())
{ *this = other; }
@@ -379,7 +379,7 @@ scalar::scalar(value_scalar value, driver::Context const & context) : array_base
scalar::scalar(numeric_type dtype, driver::Context const & context) : array_base(1, dtype, context)
{ }
scalar::scalar(math_expression const & proxy) : array_base(proxy){ }
scalar::scalar(expression_tree const & proxy) : array_base(proxy){ }
void scalar::inject(values_holder & v) const
{
@@ -511,53 +511,53 @@ shape_t broadcast(shape_t const & a, shape_t const & b)
}
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
math_expression OPNAME (array_base const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
expression_tree OPNAME (array_base const & x, expression_tree const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (array_base const & x, array_base const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
expression_tree OPNAME (array_base const & x, array_base const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); }\
\
math_expression OPNAME (array_base const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
expression_tree OPNAME (array_base const & x, value_scalar const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (array_base const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
expression_tree OPNAME (array_base const & x, for_idx_t const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
\
math_expression OPNAME (math_expression const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
expression_tree OPNAME (expression_tree const & x, expression_tree const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, array_base const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
expression_tree OPNAME (expression_tree const & x, array_base const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
\
math_expression OPNAME (math_expression const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
expression_tree OPNAME (expression_tree const & x, value_scalar const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (math_expression const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
expression_tree OPNAME (expression_tree const & x, for_idx_t const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
\
math_expression OPNAME (value_scalar const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
expression_tree OPNAME (value_scalar const & y, expression_tree const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (value_scalar const & y, array_base const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
expression_tree OPNAME (value_scalar const & y, array_base const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (value_scalar const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
expression_tree OPNAME (value_scalar const & x, for_idx_t const & y) \
{ return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); }\
\
\
math_expression OPNAME (for_idx_t const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
expression_tree OPNAME (for_idx_t const & y, expression_tree const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
\
math_expression OPNAME (for_idx_t const & y, value_scalar const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
expression_tree OPNAME (for_idx_t const & y, value_scalar const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), DTYPE); } \
\
math_expression OPNAME (for_idx_t const & y, array_base const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
expression_tree OPNAME (for_idx_t const & y, array_base const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); }\
\
math_expression OPNAME (for_idx_t const & y, for_idx_t const & x) \
{ return math_expression(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
expression_tree OPNAME (for_idx_t const & y, for_idx_t const & x) \
{ return expression_tree(y, x, op_element(BINARY_TYPE_FAMILY, OP)); }
DEFINE_ELEMENT_BINARY_OPERATOR(ADD_TYPE, operator +, x.dtype())
@@ -580,39 +580,39 @@ DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_EQ_TYPE, operator ==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(ELEMENT_NEQ_TYPE, operator !=, INT_TYPE)
#define DEFINE_OUTER(LTYPE, RTYPE) \
math_expression outer(LTYPE const & x, RTYPE const & y)\
expression_tree outer(LTYPE const & x, RTYPE const & y)\
{\
assert(x.dim()<=1 && y.dim()<=1);\
if(x.dim()<1 || y.dim()<1)\
return x*y;\
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OUTER_PROD_TYPE), x.context(), x.dtype(), {x.shape().max(), y.shape().max()} );\
}\
DEFINE_OUTER(array_base, array_base)
DEFINE_OUTER(math_expression, array_base)
DEFINE_OUTER(array_base, math_expression)
DEFINE_OUTER(math_expression, math_expression)
DEFINE_OUTER(expression_tree, array_base)
DEFINE_OUTER(array_base, expression_tree)
DEFINE_OUTER(expression_tree, expression_tree)
#undef DEFINE_ELEMENT_BINARY_OPERATOR
#define DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE)\
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
expression_tree rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s)\
{ return fuse(assign(x, c*x + s*y), assign(y, c*y - s*x)); }
DEFINE_ROT(array_base, array_base, scalar, scalar)
DEFINE_ROT(math_expression, array_base, scalar, scalar)
DEFINE_ROT(array_base, math_expression, scalar, scalar)
DEFINE_ROT(math_expression, math_expression, scalar, scalar)
DEFINE_ROT(expression_tree, array_base, scalar, scalar)
DEFINE_ROT(array_base, expression_tree, scalar, scalar)
DEFINE_ROT(expression_tree, expression_tree, scalar, scalar)
DEFINE_ROT(array_base, array_base, value_scalar, value_scalar)
DEFINE_ROT(math_expression, array_base, value_scalar, value_scalar)
DEFINE_ROT(array_base, math_expression, value_scalar, value_scalar)
DEFINE_ROT(math_expression, math_expression, value_scalar, value_scalar)
DEFINE_ROT(expression_tree, array_base, value_scalar, value_scalar)
DEFINE_ROT(array_base, expression_tree, value_scalar, value_scalar)
DEFINE_ROT(expression_tree, expression_tree, value_scalar, value_scalar)
DEFINE_ROT(array_base, array_base, math_expression, math_expression)
DEFINE_ROT(math_expression, array_base, math_expression, math_expression)
DEFINE_ROT(array_base, math_expression, math_expression, math_expression)
DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
DEFINE_ROT(array_base, array_base, expression_tree, expression_tree)
DEFINE_ROT(expression_tree, array_base, expression_tree, expression_tree)
DEFINE_ROT(array_base, expression_tree, expression_tree, expression_tree)
DEFINE_ROT(expression_tree, expression_tree, expression_tree, expression_tree)
@@ -621,11 +621,11 @@ DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
/*--- Math Operators----*/
//---------------------------------------
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
math_expression OPNAME (array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
expression_tree OPNAME (array_base const & x) \
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }\
\
math_expression OPNAME (math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
expression_tree OPNAME (expression_tree const & x) \
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, OP), x.context(), x.dtype(), x.shape()); }
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
@@ -669,14 +669,14 @@ inline operation_type casted(numeric_type dtype)
}
}
math_expression cast(array_base const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
expression_tree cast(array_base const & x, numeric_type dtype)
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
math_expression cast(math_expression const & x, numeric_type dtype)
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
expression_tree cast(expression_tree const & x, numeric_type dtype)
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, casted(dtype)), x.context(), dtype, x.shape()); }
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
isaac::expression_tree eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return expression_tree(value_scalar(1), value_scalar(0), op_element(UNARY_TYPE_FAMILY, VDIAG_TYPE), ctx, dtype, {M, N}); }
array diag(array_base & x, int offset)
{
@@ -688,8 +688,8 @@ array diag(array_base & x, int offset)
}
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
isaac::expression_tree zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return expression_tree(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
inline shape_t flip(shape_t const & shape)
{
@@ -702,77 +702,77 @@ inline shape_t flip(shape_t const & shape)
//inline size4 prod(size4 const & shape1, size4 const & shape2)
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
math_expression trans(array_base const & x) \
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
expression_tree trans(array_base const & x) \
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }\
\
math_expression trans(math_expression const & x) \
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
expression_tree trans(expression_tree const & x) \
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, TRANS_TYPE), x.context(), x.dtype(), flip(x.shape())); }
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
expression_tree repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
expression_tree repmat(expression_tree const & A, int_t const & rep1, int_t const & rep2)
{
int_t sub1 = A.shape()[0];
int_t sub2 = A.dim()==2?A.shape()[1]:1;
return math_expression(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
return expression_tree(A, make_tuple(A.context(), rep1, rep2, sub1, sub2), op_element(BINARY_TYPE_FAMILY, REPEAT_TYPE), A.context(), A.dtype(), {rep1*sub1, rep2*sub2});
}
#define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
math_expression row(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
expression_tree row(TYPEA const & x, TYPEB const & i)\
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_ROW_TYPE), x.context(), x.dtype(), {x.shape()[1]}); }
DEFINE_ACCESS_ROW(array_base, value_scalar)
DEFINE_ACCESS_ROW(array_base, for_idx_t)
DEFINE_ACCESS_ROW(array_base, math_expression)
DEFINE_ACCESS_ROW(array_base, expression_tree)
DEFINE_ACCESS_ROW(math_expression, value_scalar)
DEFINE_ACCESS_ROW(math_expression, for_idx_t)
DEFINE_ACCESS_ROW(math_expression, math_expression)
DEFINE_ACCESS_ROW(expression_tree, value_scalar)
DEFINE_ACCESS_ROW(expression_tree, for_idx_t)
DEFINE_ACCESS_ROW(expression_tree, expression_tree)
#define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
math_expression col(TYPEA const & x, TYPEB const & i)\
{ return math_expression(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
expression_tree col(TYPEA const & x, TYPEB const & i)\
{ return expression_tree(x, i, op_element(UNARY_TYPE_FAMILY, MATRIX_COLUMN_TYPE), x.context(), x.dtype(), {x.shape()[0]}); }
DEFINE_ACCESS_COL(array_base, value_scalar)
DEFINE_ACCESS_COL(array_base, for_idx_t)
DEFINE_ACCESS_COL(array_base, math_expression)
DEFINE_ACCESS_COL(array_base, expression_tree)
DEFINE_ACCESS_COL(math_expression, value_scalar)
DEFINE_ACCESS_COL(math_expression, for_idx_t)
DEFINE_ACCESS_COL(math_expression, math_expression)
DEFINE_ACCESS_COL(expression_tree, value_scalar)
DEFINE_ACCESS_COL(expression_tree, for_idx_t)
DEFINE_ACCESS_COL(expression_tree, expression_tree)
////---------------------------------------
///*--- Reductions ---*/
////---------------------------------------
#define DEFINE_REDUCTION(OP, OPNAME)\
math_expression OPNAME(array_base const & x, int_t axis)\
expression_tree OPNAME(array_base const & x, int_t axis)\
{\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}\
\
math_expression OPNAME(math_expression const & x, int_t axis)\
expression_tree OPNAME(expression_tree const & x, int_t axis)\
{\
if(axis < -1 || axis > x.dim())\
throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\
return math_expression(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
return expression_tree(x, invalid_node(), op_element(VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {1});\
else if(axis==0)\
return math_expression(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
return expression_tree(x, invalid_node(), op_element(COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[1]});\
else\
return math_expression(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
return expression_tree(x, invalid_node(), op_element(ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), {x.shape()[0]});\
}
DEFINE_REDUCTION(ADD_TYPE, sum)
@@ -786,51 +786,51 @@ DEFINE_REDUCTION(ELEMENT_ARGMIN_TYPE, argmin)
namespace detail
{
math_expression matmatprod(array_base const & A, array_base const & B)
expression_tree matmatprod(array_base const & A, array_base const & B)
{
shape_t shape{A.shape()[0], B.shape()[1]};
return math_expression(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
return expression_tree(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
}
math_expression matmatprod(math_expression const & A, array_base const & B)
expression_tree matmatprod(expression_tree const & A, array_base const & B)
{
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==TRANS_TYPE;
if(A_trans){
type = MATRIX_PRODUCT_TN_TYPE;
}
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
return res;
}
math_expression matmatprod(array_base const & A, math_expression const & B)
expression_tree matmatprod(array_base const & A, expression_tree const & B)
{
operation_type type = MATRIX_PRODUCT_NN_TYPE;
shape_t shape{A.shape()[0], B.shape()[1]};
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==TRANS_TYPE;
if(B_trans){
type = MATRIX_PRODUCT_NT_TYPE;
}
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs;
return res;
}
math_expression matmatprod(math_expression const & A, math_expression const & B)
expression_tree matmatprod(expression_tree const & A, expression_tree const & B)
{
operation_type type = MATRIX_PRODUCT_NN_TYPE;
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
math_expression::node & B_root = const_cast<math_expression::node &>(B.tree()[B.root()]);
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
expression_tree::node & B_root = const_cast<expression_tree::node &>(B.tree()[B.root()]);
shape_t shape{A.shape()[0], B.shape()[1]};
bool A_trans = A_root.op.type==TRANS_TYPE;
@@ -841,15 +841,15 @@ namespace detail
else if(!A_trans && B_trans) type = MATRIX_PRODUCT_NT_TYPE;
else type = MATRIX_PRODUCT_NN_TYPE;
math_expression res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
math_expression::node & res_root = const_cast<math_expression::node &>(res.tree()[res.root()]);
expression_tree res(A, B, op_element(MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
expression_tree::node & res_root = const_cast<expression_tree::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs;
return res;
}
template<class T>
math_expression matvecprod(array_base const & A, T const & x)
expression_tree matvecprod(array_base const & A, T const & x)
{
int_t M = A.shape()[0];
int_t N = A.shape()[1];
@@ -857,11 +857,11 @@ namespace detail
}
template<class T>
math_expression matvecprod(math_expression const & A, T const & x)
expression_tree matvecprod(expression_tree const & A, T const & x)
{
int_t M = A.shape()[0];
int_t N = A.shape()[1];
math_expression::node & A_root = const_cast<math_expression::node &>(A.tree()[A.root()]);
expression_tree::node & A_root = const_cast<expression_tree::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==TRANS_TYPE;
while(A_root.lhs.subtype==COMPOSITE_OPERATOR_TYPE){
A_root = A.tree()[A_root.lhs.node_index];
@@ -869,7 +869,7 @@ namespace detail
}
if(A_trans)
{
math_expression tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
expression_tree tmp(A, repmat(x, 1, M), op_element(BINARY_TYPE_FAMILY, ELEMENT_PROD_TYPE), A.context(), A.dtype(), {N, M});
//Remove trans
tmp.tree()[tmp.root()].lhs = A.tree()[A.root()].lhs;
return sum(tmp, 0);
@@ -889,17 +889,17 @@ ISAACAPI void swap(view x, view y)
}
//Reshape
math_expression reshape(array_base const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
expression_tree reshape(array_base const & x, shape_t const & shape)
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression reshape(math_expression const & x, shape_t const & shape)
{ return math_expression(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
expression_tree reshape(expression_tree const & x, shape_t const & shape)
{ return expression_tree(x, invalid_node(), op_element(UNARY_TYPE_FAMILY, RESHAPE_TYPE), x.context(), x.dtype(), shape); }
math_expression ravel(array_base const & x)
expression_tree ravel(array_base const & x)
{ return reshape(x, {x.shape().prod()}); }
#define DEFINE_DOT(LTYPE, RTYPE) \
math_expression dot(LTYPE const & x, RTYPE const & y)\
expression_tree dot(LTYPE const & x, RTYPE const & y)\
{\
numeric_type dtype = x.dtype();\
driver::Context const & context = x.context();\
@@ -908,7 +908,7 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
return math_expression(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
if(x.dim()==1 && y.dim()==1)\
return sum(x*y);\
if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1){\
@@ -940,15 +940,15 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
}
DEFINE_DOT(array_base, array_base)
DEFINE_DOT(math_expression, array_base)
DEFINE_DOT(array_base, math_expression)
DEFINE_DOT(math_expression, math_expression)
DEFINE_DOT(expression_tree, array_base)
DEFINE_DOT(array_base, expression_tree)
DEFINE_DOT(expression_tree, expression_tree)
#undef DEFINE_DOT
#define DEFINE_NORM(TYPE)\
math_expression norm(TYPE const & x, unsigned int order)\
expression_tree norm(TYPE const & x, unsigned int order)\
{\
assert(order > 0 && order < 3);\
switch(order)\
@@ -959,21 +959,21 @@ math_expression norm(TYPE const & x, unsigned int order)\
}
DEFINE_NORM(array_base)
DEFINE_NORM(math_expression)
DEFINE_NORM(expression_tree)
#undef DEFINE_NORM
/*--- Fusion ----*/
math_expression fuse(math_expression const & x, math_expression const & y)
expression_tree fuse(expression_tree const & x, expression_tree const & y)
{
assert(x.context()==y.context());
return math_expression(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
return expression_tree(x, y, op_element(BINARY_TYPE_FAMILY, OPERATOR_FUSE), x.context(), x.dtype(), x.shape());
}
/*--- For loops ---*/
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
ISAACAPI expression_tree sfor(expression_tree const & start, expression_tree const & end, expression_tree const & inc, expression_tree const & x)
{
return math_expression(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
return expression_tree(x, make_tuple(x.context(), start, end, inc), op_element(UNARY_TYPE_FAMILY, SFOR_TYPE), x.context(), x.dtype(), x.shape());
}
@@ -1160,7 +1160,7 @@ std::ostream& operator<<(std::ostream & os, array_base const & a)
return os;
}
ISAACAPI std::ostream& operator<<(std::ostream & oss, math_expression const & expression)
ISAACAPI std::ostream& operator<<(std::ostream & oss, expression_tree const & expression)
{
return oss << array(expression);
}

View File

@@ -51,8 +51,8 @@ void mapped_object::register_attribute(std::string & attribute, std::string cons
keywords_[key] = attribute;
}
mapped_object::node_info::node_info(mapping_type const * _mapping, isaac::math_expression const * _math_expression, size_t _root_idx) :
mapping(_mapping), math_expression(_math_expression), root_idx(_root_idx) { }
mapped_object::node_info::node_info(mapping_type const * _mapping, isaac::expression_tree const * _expression_tree, size_t _root_idx) :
mapping(_mapping), expression_tree(_expression_tree), root_idx(_root_idx) { }
mapped_object::mapped_object(std::string const & scalartype, std::string const & name, std::string const & type_key) : type_key_(type_key), name_(name)
{
@@ -92,10 +92,10 @@ std::string mapped_object::evaluate(std::map<std::string, std::string> const & a
return process(accessors.at(type_key_));
}
mapped_object& get(math_expression::container_type const & tree, size_t root, mapping_type const & mapping, size_t idx)
mapped_object& get(expression_tree::container_type const & tree, size_t root, mapping_type const & mapping, size_t idx)
{
for(unsigned int i = 0 ; i < idx ; ++i){
math_expression::node node = tree[root];
expression_tree::node node = tree[root];
if(node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
root = node.rhs.node_index;
else
@@ -108,12 +108,12 @@ binary_leaf::binary_leaf(mapped_object::node_info info) : info_(info){ }
void binary_leaf::process_recursive(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors, std::set<std::string> & already_fetched)
{
process(stream, leaf, accessors, *info_.math_expression, info_.root_idx, *info_.mapping, already_fetched);
process(stream, leaf, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping, already_fetched);
}
std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, std::string> const & accessors)
{
return evaluate(leaf, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
return evaluate(leaf, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
@@ -127,11 +127,11 @@ mapped_reduce::mapped_reduce(std::string const & scalartype, unsigned int id, no
size_t mapped_reduce::root_idx() const
{ return info_.root_idx; }
isaac::math_expression const & mapped_reduce::math_expression() const
{ return *info_.math_expression; }
isaac::expression_tree const & mapped_reduce::expression_tree() const
{ return *info_.expression_tree; }
math_expression::node mapped_reduce::root_node() const
{ return math_expression().tree()[root_idx()]; }
expression_tree::node mapped_reduce::root_node() const
{ return expression_tree().tree()[root_idx()]; }
bool mapped_reduce::is_index_reduction() const
{
@@ -144,7 +144,7 @@ bool mapped_reduce::is_index_reduction() const
op_element mapped_reduce::root_op() const
{
return info_.math_expression->tree()[info_.root_idx].op;
return info_.expression_tree->tree()[info_.root_idx].op;
}
@@ -214,10 +214,10 @@ mapped_array::mapped_array(std::string const & scalartype, unsigned int id, std:
void mapped_vdiag::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
accessors["arrayn"] = res;
accessors["host_scalar"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_vdiag::mapped_vdiag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "vdiag"), binary_leaf(info){}
@@ -226,10 +226,10 @@ mapped_vdiag::mapped_vdiag(std::string const & scalartype, unsigned int id, node
void mapped_array_access::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
tools::find_and_replace(res, "#index", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
tools::find_and_replace(res, "#index", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
accessors["arrayn"] = res;
accessors["arraynn"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_array_access::mapped_array_access(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "array_access"), binary_leaf(info)
@@ -239,9 +239,9 @@ mapped_array_access::mapped_array_access(std::string const & scalartype, unsigne
void mapped_matrix_row::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
tools::find_and_replace(res, "#row", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
tools::find_and_replace(res, "#row", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
accessors["arraynn"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_matrix_row::mapped_matrix_row(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_row"), binary_leaf(info)
@@ -251,9 +251,9 @@ mapped_matrix_row::mapped_matrix_row(std::string const & scalartype, unsigned in
void mapped_matrix_column::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
tools::find_and_replace(res, "#column", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
tools::find_and_replace(res, "#column", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
accessors["arraynn"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_matrix_column::mapped_matrix_column(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_column"), binary_leaf(info)
@@ -265,7 +265,7 @@ mapped_matrix_column::mapped_matrix_column(std::string const & scalartype, unsig
//
char mapped_repeat::get_type(node_info const & info)
{
math_expression::container_type const & tree = info.math_expression->tree();
expression_tree::container_type const & tree = info.expression_tree->tree();
size_t tuple_root = tree[info.root_idx].rhs.node_index;
int sub0 = tuple_get(tree, tuple_root, 2);
@@ -283,7 +283,7 @@ char mapped_repeat::get_type(node_info const & info)
void mapped_repeat::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
math_expression::container_type const & tree = info_.math_expression->tree();
expression_tree::container_type const & tree = info_.expression_tree->tree();
size_t tuple_root = tree[info_.root_idx].rhs.node_index;
tools::find_and_replace(res, "#rep0", get(tree, tuple_root, *info_.mapping, 0).process("#name"));
@@ -312,7 +312,7 @@ void mapped_repeat::postprocess(std::string &res) const
accessors["array1n"] = res;
accessors["arrayn1"] = res;
accessors["arraynn"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_repeat::mapped_repeat(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "repeat"), binary_leaf(info), type_(get_type(info))
@@ -324,9 +324,9 @@ mapped_repeat::mapped_repeat(std::string const & scalartype, unsigned int id, no
void mapped_matrix_diag::postprocess(std::string &res) const
{
std::map<std::string, std::string> accessors;
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping));
tools::find_and_replace(res, "#diag_offset", isaac::evaluate(RHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping));
accessors["arraynn"] = res;
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.math_expression, info_.root_idx, *info_.mapping);
res = isaac::evaluate(LHS_NODE_TYPE, accessors, *info_.expression_tree, info_.root_idx, *info_.mapping);
}
mapped_matrix_diag::mapped_matrix_diag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_diag"), binary_leaf(info)
@@ -343,7 +343,7 @@ void mapped_outer::postprocess(std::string &res) const
std::map<std::string, std::string> accessors;
accessors["arrayn"] = "$VALUE{"+i+"}";
accessors["array1"] = "#namereg";
return isaac::evaluate(leaf_, accessors, *i_.math_expression, i_.root_idx, *i_.mapping);
return isaac::evaluate(leaf_, accessors, *i_.expression_tree, i_.root_idx, *i_.mapping);
}
std::string operator()(std::string const &, std::string const &) const{return "";}
private:

View File

@@ -14,12 +14,12 @@ namespace detail
bool is_scalar_reduce_1d(math_expression::node const & node)
bool is_scalar_reduce_1d(expression_tree::node const & node)
{
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY;
}
bool is_vector_reduce_1d(math_expression::node const & node)
bool is_vector_reduce_1d(expression_tree::node const & node)
{
return node.op.type_family==ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY;
@@ -123,18 +123,18 @@ namespace detail
filter_fun::filter_fun(pred_t pred, std::vector<size_t> & out) : pred_(pred), out_(out)
{ }
void filter_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf) const
void filter_fun::operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf) const
{
math_expression::node const * root_node = &math_expression.tree()[root_idx];
expression_tree::node const * root_node = &expression_tree.tree()[root_idx];
if (leaf==PARENT_NODE_TYPE && pred_(*root_node))
out_.push_back(root_idx);
}
//
std::vector<size_t> filter_nodes(bool (*pred)(math_expression::node const & node), isaac::math_expression const & math_expression, size_t root, bool inspect)
std::vector<size_t> filter_nodes(bool (*pred)(expression_tree::node const & node), isaac::expression_tree const & expression_tree, size_t root, bool inspect)
{
std::vector<size_t> res;
traverse(math_expression, root, filter_fun(pred, res), inspect);
traverse(expression_tree, root, filter_fun(pred, res), inspect);
return res;
}
@@ -143,9 +143,9 @@ filter_elements_fun::filter_elements_fun(node_type subtype, std::vector<tree_nod
subtype_(subtype), out_(out)
{ }
void filter_elements_fun::operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t) const
void filter_elements_fun::operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t) const
{
math_expression::node const * root_node = &math_expression.tree()[root_idx];
expression_tree::node const * root_node = &expression_tree.tree()[root_idx];
if (root_node->lhs.subtype==subtype_)
out_.push_back(root_node->lhs);
if (root_node->rhs.subtype==subtype_)
@@ -153,10 +153,10 @@ void filter_elements_fun::operator()(isaac::math_expression const & math_express
}
std::vector<tree_node> filter_elements(node_type subtype, isaac::math_expression const & math_expression)
std::vector<tree_node> filter_elements(node_type subtype, isaac::expression_tree const & expression_tree)
{
std::vector<tree_node> res;
traverse(math_expression, math_expression.root(), filter_elements_fun(subtype, res), true);
traverse(expression_tree, expression_tree.root(), filter_elements_fun(subtype, res), true);
return res;
}
@@ -240,9 +240,9 @@ evaluate_expression_traversal::evaluate_expression_traversal(std::map<std::strin
accessors_(accessors), str_(str), mapping_(mapping)
{ }
void evaluate_expression_traversal::call_before_expansion(isaac::math_expression const & math_expression, std::size_t root_idx) const
void evaluate_expression_traversal::call_before_expansion(isaac::expression_tree const & expression_tree, std::size_t root_idx) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if(detail::is_cast(root_node.op))
str_ += mapping_.at(std::make_pair(root_idx, PARENT_NODE_TYPE))->evaluate(accessors_);
else if (( (root_node.op.type_family==UNARY_TYPE_FAMILY&&root_node.op.type!=ADD_TYPE) || detail::is_elementwise_function(root_node.op))
@@ -253,16 +253,16 @@ void evaluate_expression_traversal::call_before_expansion(isaac::math_expression
}
void evaluate_expression_traversal::call_after_expansion(math_expression const & math_expression, std::size_t root_idx) const
void evaluate_expression_traversal::call_after_expansion(expression_tree const & expression_tree, std::size_t root_idx) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if(root_node.op.type!=OPERATOR_FUSE)
str_+=")";
}
void evaluate_expression_traversal::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf) const
void evaluate_expression_traversal::operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
mapping_type::key_type key = std::make_pair(root_idx, leaf);
if (leaf==PARENT_NODE_TYPE)
{
@@ -304,34 +304,34 @@ void evaluate_expression_traversal::operator()(isaac::math_expression const & ma
std::string evaluate(leaf_t leaf, std::map<std::string, std::string> const & accessors,
isaac::math_expression const & math_expression, std::size_t root_idx, mapping_type const & mapping)
isaac::expression_tree const & expression_tree, std::size_t root_idx, mapping_type const & mapping)
{
std::string res;
evaluate_expression_traversal traversal_functor(accessors, res, mapping);
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf==RHS_NODE_TYPE)
{
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, false);
traverse(expression_tree, root_node.rhs.node_index, traversal_functor, false);
else
traversal_functor(math_expression, root_idx, leaf);
traversal_functor(expression_tree, root_idx, leaf);
}
else if (leaf==LHS_NODE_TYPE)
{
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, false);
traverse(expression_tree, root_node.lhs.node_index, traversal_functor, false);
else
traversal_functor(math_expression, root_idx, leaf);
traversal_functor(expression_tree, root_idx, leaf);
}
else
traverse(math_expression, root_idx, traversal_functor, false);
traverse(expression_tree, root_idx, traversal_functor, false);
return res;
}
void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
math_expression const & x, mapping_type const & mapping)
expression_tree const & x, mapping_type const & mapping)
{
stream << evaluate(leaf, accessors, x, x.root(), mapping) << std::endl;
}
@@ -341,7 +341,7 @@ process_traversal::process_traversal(std::map<std::string, std::string> const &
accessors_(accessors), stream_(stream), mapping_(mapping), already_processed_(already_processed)
{ }
void process_traversal::operator()(math_expression const & /*math_expression*/, std::size_t root_idx, leaf_t leaf) const
void process_traversal::operator()(expression_tree const & /*expression_tree*/, std::size_t root_idx, leaf_t leaf) const
{
mapping_type::const_iterator it = mapping_.find(std::make_pair(root_idx, leaf));
if (it!=mapping_.end())
@@ -362,41 +362,41 @@ void process_traversal::operator()(math_expression const & /*math_expression*/,
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
isaac::math_expression const & math_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
isaac::expression_tree const & expression_tree, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
{
process_traversal traversal_functor(accessors, stream, mapping, already_processed);
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf==RHS_NODE_TYPE)
{
if (root_node.rhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.rhs.node_index, traversal_functor, true);
traverse(expression_tree, root_node.rhs.node_index, traversal_functor, true);
else
traversal_functor(math_expression, root_idx, leaf);
traversal_functor(expression_tree, root_idx, leaf);
}
else if (leaf==LHS_NODE_TYPE)
{
if (root_node.lhs.subtype==COMPOSITE_OPERATOR_TYPE)
traverse(math_expression, root_node.lhs.node_index, traversal_functor, true);
traverse(expression_tree, root_node.lhs.node_index, traversal_functor, true);
else
traversal_functor(math_expression, root_idx, leaf);
traversal_functor(expression_tree, root_idx, leaf);
}
else
{
traverse(math_expression, root_idx, traversal_functor, true);
traverse(expression_tree, root_idx, traversal_functor, true);
}
}
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
math_expression const & x, mapping_type const & mapping)
expression_tree const & x, mapping_type const & mapping)
{
std::set<std::string> processed;
process(stream, leaf, accessors, x, x.root(), mapping, processed);
}
void math_expression_representation_functor::append_id(char * & ptr, unsigned int val)
void expression_tree_representation_functor::append_id(char * & ptr, unsigned int val)
{
if (val==0)
*ptr++='0';
@@ -408,14 +408,14 @@ void math_expression_representation_functor::append_id(char * & ptr, unsigned in
}
}
//void math_expression_representation_functor::append(driver::Buffer const & h, numeric_type dtype, char prefix, bool is_assigned) const
//void expression_tree_representation_functor::append(driver::Buffer const & h, numeric_type dtype, char prefix, bool is_assigned) const
//{
// *ptr_++=prefix;
// *ptr_++=(char)dtype;
// append_id(ptr_, binder_.get(h, is_assigned));
//}
void math_expression_representation_functor::append(tree_node const & lhs_rhs, bool is_assigned) const
void expression_tree_representation_functor::append(tree_node const & lhs_rhs, bool is_assigned) const
{
if(lhs_rhs.subtype==DENSE_ARRAY_TYPE)
{
@@ -428,18 +428,18 @@ void math_expression_representation_functor::append(tree_node const & lhs_rhs, b
}
}
math_expression_representation_functor::math_expression_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }
expression_tree_representation_functor::expression_tree_representation_functor(symbolic_binder & binder, char *& ptr) : binder_(binder), ptr_(ptr){ }
void math_expression_representation_functor::append(char*& p, const char * str) const
void expression_tree_representation_functor::append(char*& p, const char * str) const
{
std::size_t n = std::strlen(str);
std::memcpy(p, str, n);
p+=n;
}
void math_expression_representation_functor::operator()(isaac::math_expression const & math_expression, std::size_t root_idx, leaf_t leaf_t) const
void expression_tree_representation_functor::operator()(isaac::expression_tree const & expression_tree, std::size_t root_idx, leaf_t leaf_t) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
append(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)

View File

@@ -28,16 +28,16 @@ base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_si
{ }
bool base::requires_fallback(math_expression const & expression)
bool base::requires_fallback(expression_tree const & expression)
{
for(math_expression::node const & node: expression.tree())
for(expression_tree::node const & node: expression.tree())
if( (node.lhs.subtype==DENSE_ARRAY_TYPE && (node.lhs.array->stride()[0]>1 || node.lhs.array->start()>0))
|| (node.rhs.subtype==DENSE_ARRAY_TYPE && (node.rhs.array->stride()[0]>1 || node.rhs.array->start()>0)))
return true;
return false;
}
int_t base::vector_size(math_expression::node const & node)
int_t base::vector_size(expression_tree::node const & node)
{
if (node.op.type==MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
@@ -50,7 +50,7 @@ int_t base::vector_size(math_expression::node const & node)
}
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
std::pair<int_t, int_t> base::matrix_size(expression_tree::container_type const & tree, expression_tree::node const & node)
{
if (node.op.type==VDIAG_TYPE)
{
@@ -72,20 +72,20 @@ std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const
base::base(binding_policy_t binding_policy) : binding_policy_(binding_policy)
{}
unsigned int base::lmem_usage(math_expression const &) const
unsigned int base::lmem_usage(expression_tree const &) const
{ return 0; }
unsigned int base::registers_usage(math_expression const &) const
unsigned int base::registers_usage(expression_tree const &) const
{ return 0; }
unsigned int base::temporary_workspace(math_expression const &) const
unsigned int base::temporary_workspace(expression_tree const &) const
{ return 0; }
base::~base()
{
}
std::string base::generate(std::string const & suffix, math_expression const & expression, driver::Device const & device)
std::string base::generate(std::string const & suffix, expression_tree const & expression, driver::Device const & device)
{
int err = is_invalid(expression, device);
if(err != 0)
@@ -104,7 +104,7 @@ std::string base::generate(std::string const & suffix, math_expression const &
}
template<class TType, class PType>
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, math_expression const &) const
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, expression_tree const &) const
{ return TEMPLATE_VALID; }
template<class TType, class PType>
@@ -124,7 +124,7 @@ std::shared_ptr<base> base_impl<TType, PType>::clone() const
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
template<class TType, class PType>
int base_impl<TType, PType>::is_invalid(math_expression const & expressions, driver::Device const & device) const
int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, driver::Device const & device) const
{
//Query device informations
size_t lmem_available = device.local_mem_size();

View File

@@ -26,14 +26,14 @@ elementwise_1d_parameters::elementwise_1d_parameters(unsigned int _simd_width,
}
int elementwise_1d::is_invalid_impl(driver::Device const &, math_expression const &) const
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
std::string elementwise_1d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const
std::string elementwise_1d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const
{
driver::backend_type backend = device.backend();
std::string _size_t = size_type(device);
@@ -43,7 +43,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
std::string dtype = append_width("#scalartype",p_.simd_width);
std::vector<size_t> assigned_scalar = filter_nodes([](math_expression::node const & node) {
std::vector<size_t> assigned_scalar = filter_nodes([](expression_tree::node const & node) {
return detail::is_assignment(node.op) && node.lhs.subtype==DENSE_ARRAY_TYPE && node.lhs.array->shape().max()==1;
}, expressions, expressions.root(), true);
@@ -74,8 +74,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
stream.inc_tab();
math_expression::container_type const & tree = expressions.tree();
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
expression_tree::container_type const & tree = expressions.tree();
std::vector<std::size_t> sfors = filter_nodes([](expression_tree::node const & node){return node.op.type==SFOR_TYPE;}, expressions, expressions.root(), true);
for(unsigned int i = 0 ; i < sfors.size() ; ++i)
{
@@ -100,7 +100,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
if(sfors.size())
root = tree[sfors.back()].lhs.node_index;
std::vector<std::size_t> assigned = filter_nodes([](math_expression::node const & node){return detail::is_assignment(node.op);}, expressions, root, true);
std::vector<std::size_t> assigned = filter_nodes([](expression_tree::node const & node){return detail::is_assignment(node.op);}, expressions, root, true);
std::set<std::string> processed;
//Declares register to store results
@@ -182,14 +182,14 @@ elementwise_1d::elementwise_1d(unsigned int simd, unsigned int ls, unsigned int
{}
std::vector<int_t> elementwise_1d::input_sizes(math_expression const & expressions) const
std::vector<int_t> elementwise_1d::input_sizes(expression_tree const & expressions) const
{
return {expressions.shape().max()};
}
void elementwise_1d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
{
math_expression const & expressions = control.x();
expression_tree const & expressions = control.x();
//Size
int_t size = input_sizes(expressions)[0];
//Fallback

View File

@@ -20,7 +20,7 @@ elementwise_2d_parameters::elementwise_2d_parameters(unsigned int _simd_width,
int elementwise_2d::is_invalid_impl(driver::Device const &, math_expression const &) const
int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
@@ -29,7 +29,7 @@ int elementwise_2d::is_invalid_impl(driver::Device const &, math_expression cons
return TEMPLATE_VALID;
}
std::string elementwise_2d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mappings) const
std::string elementwise_2d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mappings) const
{
kernel_generation_stream stream;
std::string _size_t = size_type(device);
@@ -114,7 +114,7 @@ elementwise_2d::elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int
base_impl<elementwise_2d, elementwise_2d_parameters>(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
{}
std::vector<int_t> elementwise_2d::input_sizes(math_expression const & expression) const
std::vector<int_t> elementwise_2d::input_sizes(expression_tree const & expression) const
{
std::pair<int_t, int_t> size = matrix_size(expression.tree(), lhs_most(expression.tree(), expression.root()));
return {size.first, size.second};
@@ -122,7 +122,7 @@ std::vector<int_t> elementwise_2d::input_sizes(math_expression const & expressi
void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program const & program, std::string const & suffix, base &, execution_handler const & control)
{
math_expression const & expressions = control.x();
expression_tree const & expressions = control.x();
std::string name = "elementwise_1d";
name +=suffix;
driver::Kernel kernel(program, name.c_str());

View File

@@ -27,7 +27,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
}
unsigned int matrix_product::lmem_usage(math_expression const & expression) const
unsigned int matrix_product::lmem_usage(expression_tree const & expression) const
{
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
unsigned int N = 0;
@@ -36,7 +36,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
unsigned int matrix_product::registers_usage(math_expression const & expression) const
unsigned int matrix_product::registers_usage(expression_tree const & expression) const
{
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
@@ -44,7 +44,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
unsigned int matrix_product::temporary_workspace(math_expression const & expressions) const
unsigned int matrix_product::temporary_workspace(expression_tree const & expressions) const
{
std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1];
@@ -53,7 +53,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return 0;
}
int matrix_product::is_invalid_impl(driver::Device const &, math_expression const &) const
int matrix_product::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
// if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
// return TEMPLATE_INVALID_SIMD_WIDTH;
@@ -103,7 +103,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return TEMPLATE_VALID;
}
std::string matrix_product::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const &) const
std::string matrix_product::generate_impl(std::string const & suffix, expression_tree const & expression, driver::Device const & device, mapping_type const &) const
{
using std::string;
using tools::to_string;
@@ -651,9 +651,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
}
std::vector<int_t> matrix_product::infos(math_expression const & expression, symbolic::preset::matrix_product::args& arguments) const
std::vector<int_t> matrix_product::infos(expression_tree const & expression, symbolic::preset::matrix_product::args& arguments) const
{
math_expression::container_type const & array = expression.tree();
expression_tree::container_type const & array = expression.tree();
std::size_t root = expression.root();
arguments = symbolic::preset::matrix_product::check(array, root);
int_t M = arguments.C->array->shape()[0];
@@ -671,10 +671,10 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
else throw;
}
std::vector<int_t> matrix_product::input_sizes(math_expression const & expressions) const
std::vector<int_t> matrix_product::input_sizes(expression_tree const & expressions) const
{
symbolic::preset::matrix_product::args dummy;
return infos((math_expression&)expressions, dummy);
return infos((expression_tree&)expressions, dummy);
}
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, execution_handler const & control)
@@ -682,7 +682,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
using namespace tools;
matrix_product & fallback = (matrix_product&)fallback_base;
math_expression const & expressions = control.x();
expression_tree const & expressions = control.x();
symbolic::preset::matrix_product::args args;

View File

@@ -20,20 +20,20 @@ reduce_1d_parameters::reduce_1d_parameters(unsigned int _simd_width,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ }
unsigned int reduce_1d::lmem_usage(math_expression const & x) const
unsigned int reduce_1d::lmem_usage(expression_tree const & x) const
{
numeric_type numeric_t= lhs_most(x.tree(), x.root()).lhs.dtype;
return p_.local_size_0*size_of(numeric_t);
}
int reduce_1d::is_invalid_impl(driver::Device const &, math_expression const &) const
int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int reduce_1d::temporary_workspace(math_expression const &) const
unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
{
if(p_.num_groups > 1)
return p_.num_groups;
@@ -65,7 +65,7 @@ inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream,
stream << "}" << std::endl;
}
std::string reduce_1d::generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const
std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, mapping_type const & mapping) const
{
kernel_generation_stream stream;
@@ -87,7 +87,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, math_expression
unsigned int offset = 0;
for (unsigned int k = 0; k < N; ++k)
{
numeric_type dtype = lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype;
numeric_type dtype = lhs_most(exprs[k]->expression_tree().tree(), exprs[k]->expression_tree().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (exprs[k]->is_index_reduction())
{
@@ -298,7 +298,7 @@ reduce_1d::reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng,
base_impl<reduce_1d, reduce_1d_parameters>(reduce_1d_parameters(simd,ls,ng,fetch), bind)
{}
std::vector<int_t> reduce_1d::input_sizes(math_expression const & x) const
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const
{
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, x, x.root(), false);
int_t N = vector_size(lhs_most(x.tree(), reduce_1ds_idx[0]));
@@ -307,7 +307,7 @@ std::vector<int_t> reduce_1d::input_sizes(math_expression const & x) const
void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
{
math_expression const & x = control.x();
expression_tree const & x = control.x();
//Preprocessing
int_t size = input_sizes(x)[0];
@@ -319,7 +319,7 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
return;
}
std::vector<math_expression::node const *> reduce_1ds;
std::vector<expression_tree::node const *> reduce_1ds;
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, x, x.root(), false);
for (size_t idx: reduce_1ds_idx)
reduce_1ds.push_back(&x.tree()[idx]);

View File

@@ -22,19 +22,19 @@ reduce_2d_parameters::reduce_2d_parameters(unsigned int _simd_width,
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
int reduce_2d::is_invalid_impl(driver::Device const &, math_expression const &) const
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetch_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int reduce_2d::lmem_usage(const math_expression&) const
unsigned int reduce_2d::lmem_usage(const expression_tree&) const
{
return (p_.local_size_0+1)*p_.local_size_1;
}
unsigned int reduce_2d::temporary_workspace(math_expression const & expressions) const
unsigned int reduce_2d::temporary_workspace(expression_tree const & expressions) const
{
std::vector<int_t> MN = input_sizes(expressions);
int_t M = MN[0];
@@ -43,7 +43,7 @@ unsigned int reduce_2d::temporary_workspace(math_expression const & expressions)
return 0;
}
std::string reduce_2d::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const
std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree const & expression, driver::Device const & device, mapping_type const & mapping) const
{
using tools::to_string;
@@ -67,7 +67,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, math_expression
unsigned int offset = 0;
for (const auto & e : reduce_1ds)
{
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
numeric_type dtype = lhs_most(e->expression_tree().tree(), e->expression_tree().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (e->is_index_reduction())
{
@@ -360,7 +360,7 @@ reduce_2d::reduce_2d(reduce_2d::parameters_type const & parameters,
base_impl<reduce_2d, reduce_2d_parameters>(parameters, binding_policy),
reduce_1d_type_(rtype){ }
std::vector<int_t> reduce_2d::input_sizes(math_expression const & expression) const
std::vector<int_t> reduce_2d::input_sizes(expression_tree const & expression) const
{
std::vector<std::size_t> idx = filter_nodes(&is_reduce_1d, expression, expression.root(), false);
std::pair<int_t, int_t> MN = matrix_size(expression.tree(), lhs_most(expression.tree(), idx[0]));
@@ -371,10 +371,10 @@ std::vector<int_t> reduce_2d::input_sizes(math_expression const & expression) co
void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
{
math_expression const & expression = control.x();
expression_tree const & expression = control.x();
std::vector<int_t> MN = input_sizes(expression);
std::vector<math_expression::node const *> reduce_1ds;
std::vector<expression_tree::node const *> reduce_1ds;
std::vector<size_t> reduce_1ds_idx = filter_nodes(&is_reduce_1d, expression, expression.root(), false);
for (size_t idx : reduce_1ds_idx)
reduce_1ds.push_back(&expression.tree()[idx]);

View File

@@ -12,7 +12,7 @@ namespace templates
{
//Generate
inline std::string generate_arguments(std::string const &, driver::Device const & device, mapping_type const & mappings, math_expression const & expressions)
inline std::string generate_arguments(std::string const &, driver::Device const & device, mapping_type const & mappings, expression_tree const & expressions)
{
std::string kwglobal = Global(device.backend()).get();
std::string _size_t = size_type(device);
@@ -90,9 +90,9 @@ public:
}
}
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf_t) const
{
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
@@ -106,7 +106,7 @@ private:
driver::Kernel & kernel_;
};
inline void set_arguments(math_expression const & expression, driver::Kernel & kernel, unsigned int & current_arg, binding_policy_t binding_policy)
inline void set_arguments(expression_tree const & expression, driver::Kernel & kernel, unsigned int & current_arg, binding_policy_t binding_policy)
{
std::unique_ptr<symbolic_binder> binder;
if (binding_policy==BIND_SEQUENTIAL)

View File

@@ -12,18 +12,18 @@ namespace templates
class map_functor : public traversal_functor
{
numeric_type get_numeric_type(isaac::math_expression const * math_expression, size_t root_idx) const
numeric_type get_numeric_type(isaac::expression_tree const * expression_tree, size_t root_idx) const
{
math_expression::node const * root_node = &math_expression->tree()[root_idx];
expression_tree::node const * root_node = &expression_tree->tree()[root_idx];
while (root_node->lhs.dtype==INVALID_NUMERIC_TYPE)
root_node = &math_expression->tree()[root_node->lhs.node_index];
root_node = &expression_tree->tree()[root_node->lhs.node_index];
return root_node->lhs.dtype;
}
template<class T>
std::shared_ptr<mapped_object> binary_leaf(isaac::math_expression const * math_expression, size_t root_idx, mapping_type const * mapping) const
std::shared_ptr<mapped_object> binary_leaf(isaac::expression_tree const * expression_tree, size_t root_idx, mapping_type const * mapping) const
{
return std::shared_ptr<mapped_object>(new T(to_string(math_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, math_expression, root_idx)));
return std::shared_ptr<mapped_object>(new T(to_string(expression_tree->dtype()), binder_.get(), mapped_object::node_info(mapping, expression_tree, root_idx)));
}
std::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const
@@ -60,10 +60,10 @@ public:
{
}
void operator()(isaac::math_expression const & math_expression, size_t root_idx, leaf_t leaf_t) const
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf_t) const
{
mapping_type::key_type key(root_idx, leaf_t);
math_expression::node const & root_node = math_expression.tree()[root_idx];
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf_t == LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
mapping_.insert(mapping_type::value_type(key, create(root_node.lhs, detail::is_assignment(root_node.op))));
@@ -72,25 +72,25 @@ public:
else if ( leaf_t== PARENT_NODE_TYPE)
{
if (root_node.op.type==VDIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vdiag>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type==MATRIX_DIAG_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type==MATRIX_ROW_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type==MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&expression_tree, root_idx, &mapping_)));
else if(root_node.op.type==ACCESS_INDEX_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_array_access>(&expression_tree, root_idx, &mapping_)));
else if (detail::is_scalar_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_1d>(&expression_tree, root_idx, &mapping_)));
else if (detail::is_vector_reduce_1d(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_reduce_2d>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type_family == MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type == REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&expression_tree, root_idx, &mapping_)));
else if (root_node.op.type == OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&math_expression, root_idx, &mapping_)));
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&expression_tree, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
}

View File

@@ -55,7 +55,7 @@ inline std::string neutral_element(op_element const & op, driver::backend_type b
}
}
inline bool is_reduce_1d(math_expression::node const & node)
inline bool is_reduce_1d(expression_tree::node const & node)
{
return node.op.type_family==VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==COLUMNS_DOT_TYPE_FAMILY

View File

@@ -39,7 +39,7 @@ driver::Program const & profiles::value_type::init(execution_handler const & exp
char* ptr = program_name;
bind_independent binder;
traverse(expression.x(), expression.x().root(), math_expression_representation_functor(binder, ptr),true);
traverse(expression.x(), expression.x().root(), expression_tree_representation_functor(binder, ptr),true);
*ptr='\0';
pname = std::string(program_name);

View File

@@ -105,12 +105,12 @@ namespace isaac
}
/** @brief Parses the breakpoints for a given expression tree */
static void parse(math_expression::container_type&array, size_t idx,
static void parse(expression_tree::container_type&array, size_t idx,
breakpoints_t & breakpoints,
expression_type & final_type,
bool is_first = true)
{
math_expression::node & node = array[idx];
expression_tree::node & node = array[idx];
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
//Left
@@ -146,14 +146,14 @@ namespace isaac
}
}
/** @brief Executes a math_expression on the given models map*/
/** @brief Executes a expression_tree on the given models map*/
void execute(execution_handler const & c, profiles::map_type & profiles)
{
math_expression expression = c.x();
expression_tree expression = c.x();
driver::Context const & context = expression.context();
size_t rootidx = expression.root();
math_expression::container_type & tree = const_cast<math_expression::container_type &>(expression.tree());
math_expression::node root_save = tree[rootidx];
expression_tree::container_type & tree = const_cast<expression_tree::container_type &>(expression.tree());
expression_tree::node root_save = tree[rootidx];
//Todo: technically the datatype should be per temporary
numeric_type dtype = expression.dtype();
@@ -186,8 +186,8 @@ namespace isaac
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
{
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(it->first, dtype)];
math_expression::node const & node = tree[it->second->node_index];
math_expression::node const & lmost = lhs_most(tree, node);
expression_tree::node const & node = tree[it->second->node_index];
expression_tree::node const & lmost = lhs_most(tree, node);
//Creates temporary
std::shared_ptr<array> tmp;
@@ -217,7 +217,7 @@ namespace isaac
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
//Incorporates the temporary within, the math_expression
//Incorporates the temporary within, the expression_tree
fill(*it->second, (array&)*tmp);
}
}

View File

@@ -49,7 +49,7 @@ op_element::op_element() {}
op_element::op_element(operation_type_family const & _type_family, operation_type const & _type) : type_family(_type_family), type(_type){}
//
math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
expression_tree::expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op)
: tree_(1), root_(0), context_(NULL), dtype_(INVALID_NUMERIC_TYPE), shape_(1)
{
fill(tree_[0].lhs, lhs);
@@ -57,7 +57,7 @@ math_expression::math_expression(for_idx_t const &lhs, for_idx_t const &rhs, con
fill(tree_[0].rhs, rhs);
}
math_expression::math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype)
expression_tree::expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype)
: tree_(1), root_(0), context_(NULL), dtype_(dtype), shape_(1)
{
fill(tree_[0].lhs, lhs);
@@ -65,7 +65,7 @@ math_expression::math_expression(for_idx_t const &lhs, value_scalar const &rhs,
fill(tree_[0].rhs, rhs);
}
math_expression::math_expression(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype)
expression_tree::expression_tree(value_scalar const &lhs, for_idx_t const &rhs, const op_element &op, const numeric_type &dtype)
: tree_(1), root_(0), context_(NULL), dtype_(dtype), shape_(1)
{
fill(tree_[0].lhs, lhs);
@@ -75,11 +75,11 @@ math_expression::math_expression(value_scalar const &lhs, for_idx_t const &rhs,
//math_expression(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
//math_expression(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
//expression_tree(for_idx_t const &lhs, for_idx_t const &rhs, const op_element &op);
//expression_tree(for_idx_t const &lhs, value_scalar const &rhs, const op_element &op, const numeric_type &dtype);
template<class LT, class RT>
math_expression::math_expression(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
expression_tree::expression_tree(LT const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
tree_(1), root_(0), context_(&context), dtype_(dtype), shape_(shape)
{
fill(tree_[0].lhs, lhs);
@@ -88,7 +88,7 @@ math_expression::math_expression(LT const & lhs, RT const & rhs, op_element cons
}
template<class RT>
math_expression::math_expression(math_expression const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
expression_tree::expression_tree(expression_tree const & lhs, RT const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
tree_(lhs.tree_.size() + 1), root_(tree_.size()-1), context_(&context), dtype_(dtype), shape_(shape)
{
std::copy(lhs.tree_.begin(), lhs.tree_.end(), tree_.begin());
@@ -98,7 +98,7 @@ math_expression::math_expression(math_expression const & lhs, RT const & rhs, op
}
template<class LT>
math_expression::math_expression(LT const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
expression_tree::expression_tree(LT const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape) :
tree_(rhs.tree_.size() + 1), root_(tree_.size() - 1), context_(&context), dtype_(dtype), shape_(shape)
{
std::copy(rhs.tree_.begin(), rhs.tree_.end(), tree_.begin());
@@ -107,7 +107,7 @@ math_expression::math_expression(LT const & lhs, math_expression const & rhs, op
fill(tree_[root_].rhs, rhs.root_);
}
math_expression::math_expression(math_expression const & lhs, math_expression const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape):
expression_tree::expression_tree(expression_tree const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const & context, numeric_type const & dtype, shape_t const & shape):
tree_(lhs.tree_.size() + rhs.tree_.size() + 1), root_(tree_.size()-1), context_(&context), dtype_(dtype), shape_(shape)
{
std::size_t lsize = lhs.tree_.size();
@@ -123,84 +123,77 @@ math_expression::math_expression(math_expression const & lhs, math_expression co
root_ = tree_.size() - 1;
}
template math_expression::math_expression(math_expression const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(math_expression const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(math_expression const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(math_expression const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(expression_tree const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(expression_tree const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(expression_tree const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(expression_tree const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(value_scalar const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(value_scalar const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(value_scalar const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(value_scalar const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(value_scalar const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(value_scalar const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(value_scalar const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(value_scalar const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(invalid_node const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(invalid_node const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(invalid_node const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(invalid_node const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(invalid_node const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(invalid_node const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(invalid_node const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(invalid_node const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(array_base const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(array_base const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(array_base const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(array_base const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(array_base const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(array_base const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(array_base const &, value_scalar const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(array_base const &, invalid_node const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(array_base const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(array_base const &, for_idx_t const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(for_idx_t const &, math_expression const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template math_expression::math_expression(for_idx_t const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(for_idx_t const &, expression_tree const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
template expression_tree::expression_tree(for_idx_t const &, array_base const &, op_element const &, driver::Context const &, numeric_type const &, shape_t const &);
math_expression::container_type & math_expression::tree()
expression_tree::container_type & expression_tree::tree()
{ return tree_; }
math_expression::container_type const & math_expression::tree() const
expression_tree::container_type const & expression_tree::tree() const
{ return tree_; }
std::size_t math_expression::root() const
std::size_t expression_tree::root() const
{ return root_; }
driver::Context const & math_expression::context() const
driver::Context const & expression_tree::context() const
{ return *context_; }
numeric_type const & math_expression::dtype() const
numeric_type const & expression_tree::dtype() const
{ return dtype_; }
shape_t math_expression::shape() const
shape_t expression_tree::shape() const
{ return shape_; }
int_t math_expression::dim() const
int_t expression_tree::dim() const
{ return (int_t)shape_.size(); }
//math_expression& math_expression::reshape(int_t size1, int_t size2)
//{
// assert(size1*size2==prod(shape_));
// shape_ = size4(size1, size2);
// return *this;
//}
expression_tree expression_tree::operator-()
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
math_expression math_expression::operator-()
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, SUB_TYPE), *context_, dtype_, shape_); }
math_expression math_expression::operator!()
{ return math_expression(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
expression_tree expression_tree::operator!()
{ return expression_tree(*this, invalid_node(), op_element(UNARY_TYPE_FAMILY, NEGATE_TYPE), *context_, INT_TYPE, shape_); }
//
math_expression::node const & lhs_most(math_expression::container_type const & array, math_expression::node const & init)
expression_tree::node const & lhs_most(expression_tree::container_type const & array, expression_tree::node const & init)
{
math_expression::node const * current = &init;
expression_tree::node const * current = &init;
while (current->lhs.subtype==COMPOSITE_OPERATOR_TYPE)
current = &array[current->lhs.node_index];
return *current;
}
math_expression::node const & lhs_most(math_expression::container_type const & array, size_t root)
expression_tree::node const & lhs_most(expression_tree::container_type const & array, size_t root)
{ return lhs_most(array, array[root]); }
//
math_expression for_idx_t::operator=(value_scalar const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
math_expression for_idx_t::operator=(math_expression const & r) const { return math_expression(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
expression_tree for_idx_t::operator=(value_scalar const & r) const { return expression_tree(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.dtype()); }
expression_tree for_idx_t::operator=(expression_tree const & r) const { return expression_tree(*this, r, op_element(BINARY_TYPE_FAMILY,ASSIGN_TYPE), r.context(), r.dtype(), r.shape()); }
math_expression for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
math_expression for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
math_expression for_idx_t::operator*=(value_scalar const & r) const { return *this = *this * r; }
math_expression for_idx_t::operator/=(value_scalar const & r) const { return *this = *this / r; }
expression_tree for_idx_t::operator+=(value_scalar const & r) const { return *this = *this + r; }
expression_tree for_idx_t::operator-=(value_scalar const & r) const { return *this = *this - r; }
expression_tree for_idx_t::operator*=(value_scalar const & r) const { return *this = *this * r; }
expression_tree for_idx_t::operator/=(value_scalar const & r) const { return *this = *this / r; }
}

View File

@@ -30,7 +30,7 @@ inline std::string to_string(tree_node const & e)
return tools::to_string(e.subtype);
}
inline std::ostream & operator<<(std::ostream & os, math_expression::node const & s_node)
inline std::ostream & operator<<(std::ostream & os, expression_tree::node const & s_node)
{
os << "LHS: " << to_string(s_node.lhs) << "|" << s_node.lhs.dtype << ", "
<< "OP: " << s_node.op.type_family << " | " << s_node.op.type << ", "
@@ -42,11 +42,11 @@ inline std::ostream & operator<<(std::ostream & os, math_expression::node const
namespace detail
{
/** @brief Recursive worker routine for printing a whole math_expression */
inline void print_node(std::ostream & os, isaac::math_expression const & s, size_t node_index, size_t indent = 0)
/** @brief Recursive worker routine for printing a whole expression_tree */
inline void print_node(std::ostream & os, isaac::expression_tree const & s, size_t node_index, size_t indent = 0)
{
math_expression::container_type const & nodes = s.tree();
math_expression::node const & current_node = nodes[node_index];
expression_tree::container_type const & nodes = s.tree();
expression_tree::node const & current_node = nodes[node_index];
for (size_t i=0; i<indent; ++i)
os << " ";
@@ -61,7 +61,7 @@ namespace detail
}
}
std::string to_string(isaac::math_expression const & s)
std::string to_string(isaac::expression_tree const & s)
{
std::ostringstream os;
detail::print_node(os, s, s.root());

View File

@@ -9,7 +9,7 @@ namespace symbolic
namespace preset
{
void matrix_product::handle_node(math_expression::container_type const & tree, size_t rootidx, args & a)
void matrix_product::handle_node(expression_tree::container_type const & tree, size_t rootidx, args & a)
{
//Matrix-Matrix product node
if(tree[rootidx].op.type_family==MATRIX_PRODUCT_TYPE_FAMILY)
@@ -46,7 +46,7 @@ void matrix_product::handle_node(math_expression::container_type const & tree, s
}
}
matrix_product::args matrix_product::check(math_expression::container_type const & tree, size_t rootidx)
matrix_product::args matrix_product::check(expression_tree::container_type const & tree, size_t rootidx)
{
tree_node const * assigned = &tree[rootidx].lhs;
numeric_type dtype = assigned->dtype;

View File

@@ -50,7 +50,7 @@ INSTANTIATE(double)
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
value_scalar::value_scalar(math_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
value_scalar::value_scalar(expression_tree const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
values_holder value_scalar::values() const
{ return values_; }

View File

@@ -18,7 +18,7 @@ extern "C"
isaac::driver::backend::release();
}
void execute(sc::math_expression const & operation, sc::driver::Context const & context,
void execute(sc::expression_tree const & operation, sc::driver::Context const & context,
cl_uint numCommandQueues, cl_command_queue *commandQueues,
cl_uint numEventsInWaitList, const cl_event *eventWaitList,
cl_event *events)

View File

@@ -252,15 +252,15 @@ void export_core()
#undef INSTANTIATE
bp::enum_<sc::expression_type>("operations")
MAP_ENUM(AXPY_TYPE, sc)
MAP_ENUM(GER_TYPE, sc)
MAP_ENUM(DOT_TYPE, sc)
MAP_ENUM(GEMV_N_TYPE, sc)
MAP_ENUM(GEMV_T_TYPE, sc)
MAP_ENUM(GEMM_NN_TYPE, sc)
MAP_ENUM(GEMM_TN_TYPE, sc)
MAP_ENUM(GEMM_NT_TYPE, sc)
MAP_ENUM(GEMM_TT_TYPE, sc);
MAP_ENUM(ELEMENTWISE_1D, sc)
MAP_ENUM(ELEMENTWISE_2D, sc)
MAP_ENUM(REDUCE_1D, sc)
MAP_ENUM(REDUCE_2D_ROWS, sc)
MAP_ENUM(REDUCE_2D_COLS, sc)
MAP_ENUM(MATRIX_PRODUCT_NN, sc)
MAP_ENUM(MATRIX_PRODUCT_TN, sc)
MAP_ENUM(MATRIX_PRODUCT_NT, sc)
MAP_ENUM(MATRIX_PRODUCT_TT, sc);
#define ADD_SCALAR_HANDLING(OP)\
.def(bp::self OP int())\
@@ -276,7 +276,7 @@ void export_core()
.def(bp::self OP bp::self)\
ADD_SCALAR_HANDLING(OP)
bp::class_<sc::math_expression >("math_expression", bp::no_init)
bp::class_<sc::expression_tree >("expression_tree", bp::no_init)
ADD_ARRAY_OPERATOR(+)
ADD_ARRAY_OPERATOR(-)
ADD_ARRAY_OPERATOR(*)
@@ -287,7 +287,7 @@ void export_core()
ADD_ARRAY_OPERATOR(<=)
ADD_ARRAY_OPERATOR(==)
ADD_ARRAY_OPERATOR(!=)
.add_property("context", bp::make_function(&sc::math_expression::context, bp::return_internal_reference<>()))
.add_property("context", bp::make_function(&sc::expression_tree::context, bp::return_internal_reference<>()))
.def(bp::self_ns::abs(bp::self))
// .def(bp::self_ns::pow(bp::self))
;
@@ -295,8 +295,8 @@ void export_core()
#define ADD_ARRAY_OPERATOR(OP) \
.def(bp::self OP bp::self)\
.def(bp::self OP bp::other<sc::math_expression>())\
.def(bp::other<sc::math_expression>() OP bp::self) \
.def(bp::self OP bp::other<sc::expression_tree>())\
.def(bp::other<sc::expression_tree>() OP bp::self) \
ADD_SCALAR_HANDLING(OP)
bp::class_<sc::array_base, boost::noncopyable>("array_base", bp::no_init)
@@ -322,7 +322,7 @@ void export_core()
bp::class_<sc::array,std::shared_ptr<sc::array>, bp::bases<sc::array_base> >
( "array", bp::no_init)
.def("__init__", bp::make_constructor(detail::create_array, bp::default_call_policies(), (bp::arg("obj"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")= bp::object())))
.def(bp::init<sc::math_expression>())
.def(bp::init<sc::expression_tree>())
;
bp::class_<sc::view, bp::bases<sc::array_base> >
@@ -338,15 +338,15 @@ void export_core()
bp::def("empty", &detail::create_empty_array, (bp::arg("shape"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")=bp::object()));
//Assign
bp::def("assign", static_cast<sc::math_expression (*)(sc::array_base const &, sc::array_base const &)>(&sc::assign));\
bp::def("assign", static_cast<sc::math_expression (*)(sc::array_base const &, sc::math_expression const &)>(&sc::assign));\
bp::def("assign", static_cast<sc::expression_tree (*)(sc::array_base const &, sc::array_base const &)>(&sc::assign));\
bp::def("assign", static_cast<sc::expression_tree (*)(sc::array_base const &, sc::expression_tree const &)>(&sc::assign));\
//Binary
#define MAP_FUNCTION(name) \
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::math_expression const &)>(&sc::name));\
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::math_expression const &)>(&sc::name));
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::expression_tree const &)>(&sc::name));\
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::expression_tree const &)>(&sc::name));
MAP_FUNCTION(maximum)
MAP_FUNCTION(minimum)
@@ -356,8 +356,8 @@ void export_core()
//Unary
#define MAP_FUNCTION(name) \
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &)>(&sc::name));
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &)>(&sc::name));\
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &)>(&sc::name));
bp::def("zeros", &detail::create_zeros_array, (bp::arg("shape"), bp::arg("dtype") = bp::scope().attr("float32"), bp::arg("context")=bp::object()));
@@ -382,8 +382,8 @@ void export_core()
/*--- Reduction operators----*/
//---------------------------------------
#define MAP_FUNCTION(name) \
bp::def(#name, static_cast<sc::math_expression (*)(sc::array_base const &, sc::int_t)>(&sc::name));\
bp::def(#name, static_cast<sc::math_expression (*)(sc::math_expression const &, sc::int_t)>(&sc::name));
bp::def(#name, static_cast<sc::expression_tree (*)(sc::array_base const &, sc::int_t)>(&sc::name));\
bp::def(#name, static_cast<sc::expression_tree (*)(sc::expression_tree const &, sc::int_t)>(&sc::name));
MAP_FUNCTION(sum)
MAP_FUNCTION(max)

View File

@@ -62,7 +62,7 @@ namespace detail
std::shared_ptr<sc::driver::Context> make_context(sc::driver::Device const & dev)
{ return std::shared_ptr<sc::driver::Context>(new sc::driver::Context(dev)); }
bp::object enqueue(sc::math_expression const & expression, unsigned int queue_id, bp::list dependencies, bool tune, int label, std::string const & program_name, bool force_recompile)
bp::object enqueue(sc::expression_tree const & expression, unsigned int queue_id, bp::list dependencies, bool tune, int label, std::string const & program_name, bool force_recompile)
{
std::list<sc::driver::Event> events;
std::vector<sc::driver::Event> cdependencies = tools::to_vector<sc::driver::Event>(dependencies);
@@ -70,7 +70,7 @@ namespace detail
sc::execution_options_type execution_options(queue_id, &events, &cdependencies);
sc::dispatcher_options_type dispatcher_options(tune, label);
sc::compilation_options_type compilation_options(program_name, force_recompile);
sc::math_expression::container_type::value_type root = expression.tree()[expression.root()];
sc::expression_tree::container_type::value_type root = expression.tree()[expression.root()];
if(sc::detail::is_assignment(root.op))
{
sc::execute(sc::execution_handler(expression, execution_options, dispatcher_options, compilation_options), isaac::profiles::get(execution_options.queue(expression.context())));

View File

@@ -13,7 +13,7 @@ namespace tpt = isaac::templates;
namespace detail
{
bp::list input_sizes(tpt::base & temp, sc::math_expression const & tree)
bp::list input_sizes(tpt::base & temp, sc::expression_tree const & tree)
{
std::vector<isaac::int_t> tmp = temp.input_sizes(tree);
return tools::to_list(tmp.begin(), tmp.end());