2015-02-01 17:15:41 -05:00
|
|
|
#ifndef _ATIDLAS_SYMBOLIC_EXPRESSION_H
|
|
|
|
#define _ATIDLAS_SYMBOLIC_EXPRESSION_H
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
#include <list>
|
2015-02-01 17:15:41 -05:00
|
|
|
#include <CL/cl.hpp>
|
2015-01-12 13:20:53 -05:00
|
|
|
#include "atidlas/types.h"
|
|
|
|
#include "atidlas/value_scalar.h"
|
2015-02-04 22:06:15 -05:00
|
|
|
#include <memory>
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
namespace atidlas
|
|
|
|
{
|
|
|
|
|
|
|
|
// Forward declarations of types defined in other headers.
// `array` is only needed by const reference in this header (see the fill() overloads).
// NOTE(review): repeat_infos is also stored *by value* inside lhs_rhs_element's anonymous
// union, which requires a complete type, so its full definition must already be visible
// from one of the included headers (presumably atidlas/types.h) — confirm.
class array;
class repeat_infos;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
/**
 * @brief Coarse category of an operation node.
 *
 * Grouping the many operation_node_type values into a handful of families lets
 * dispatch code branch on the category first instead of on every individual
 * operation. Values are spelled out explicitly; they match the implicit
 * numbering 0..6.
 */
enum operation_node_type_family
{
  OPERATOR_INVALID_TYPE_FAMILY = 0,

  /* BLAS1-like families */
  OPERATOR_UNARY_TYPE_FAMILY = 1,
  OPERATOR_BINARY_TYPE_FAMILY = 2,
  OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY = 3,

  /* BLAS2-like families */
  OPERATOR_ROWS_REDUCTION_TYPE_FAMILY = 4,
  OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY = 5,

  /* BLAS3-like families */
  OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY = 6
};
|
|
|
|
|
|
|
|
/**
 * @brief Exact operation carried by an expression node.
 *
 * Enumerators take implicit, consecutive values starting at 0, so their order
 * is significant: inserting an entry renumbers everything after it. The groups
 * below mirror the original layout.
 */
enum operation_node_type
{
  OPERATOR_INVALID_TYPE = 0,

  /* ---- Unary operators ---- */
  OPERATOR_MINUS_TYPE,
  OPERATOR_NEGATE_TYPE,

  /* ---- Unary expressions: casts ---- */
  OPERATOR_CAST_BOOL_TYPE,
  OPERATOR_CAST_CHAR_TYPE,
  OPERATOR_CAST_UCHAR_TYPE,
  OPERATOR_CAST_SHORT_TYPE,
  OPERATOR_CAST_USHORT_TYPE,
  OPERATOR_CAST_INT_TYPE,
  OPERATOR_CAST_UINT_TYPE,
  OPERATOR_CAST_LONG_TYPE,
  OPERATOR_CAST_ULONG_TYPE,
  OPERATOR_CAST_HALF_TYPE,
  OPERATOR_CAST_FLOAT_TYPE,
  OPERATOR_CAST_DOUBLE_TYPE,

  /* ---- Unary expressions: math functions ---- */
  OPERATOR_ABS_TYPE,
  OPERATOR_ACOS_TYPE,
  OPERATOR_ASIN_TYPE,
  OPERATOR_ATAN_TYPE,
  OPERATOR_CEIL_TYPE,
  OPERATOR_COS_TYPE,
  OPERATOR_COSH_TYPE,
  OPERATOR_EXP_TYPE,
  OPERATOR_FABS_TYPE,
  OPERATOR_FLOOR_TYPE,
  OPERATOR_LOG_TYPE,
  OPERATOR_LOG10_TYPE,
  OPERATOR_SIN_TYPE,
  OPERATOR_SINH_TYPE,
  OPERATOR_SQRT_TYPE,
  OPERATOR_TAN_TYPE,
  OPERATOR_TANH_TYPE,

  /* ---- Unary expressions: transposition ---- */
  OPERATOR_TRANS_TYPE,

  /* ---- Binary expressions: access and assignment ---- */
  OPERATOR_ACCESS_TYPE,
  OPERATOR_ASSIGN_TYPE,
  OPERATOR_INPLACE_ADD_TYPE,
  OPERATOR_INPLACE_SUB_TYPE,

  /* ---- Binary expressions: arithmetic ---- */
  OPERATOR_ADD_TYPE,
  OPERATOR_SUB_TYPE,
  OPERATOR_MULT_TYPE,
  OPERATOR_DIV_TYPE,

  /* ---- Binary expressions: element-wise arg-min/max ---- */
  OPERATOR_ELEMENT_ARGFMAX_TYPE,
  OPERATOR_ELEMENT_ARGFMIN_TYPE,
  OPERATOR_ELEMENT_ARGMAX_TYPE,
  OPERATOR_ELEMENT_ARGMIN_TYPE,

  /* ---- Binary expressions: element-wise arithmetic ---- */
  OPERATOR_ELEMENT_PROD_TYPE,
  OPERATOR_ELEMENT_DIV_TYPE,

  /* ---- Binary expressions: element-wise comparisons ---- */
  OPERATOR_ELEMENT_EQ_TYPE,
  OPERATOR_ELEMENT_NEQ_TYPE,
  OPERATOR_ELEMENT_GREATER_TYPE,
  OPERATOR_ELEMENT_GEQ_TYPE,
  OPERATOR_ELEMENT_LESS_TYPE,
  OPERATOR_ELEMENT_LEQ_TYPE,

  /* ---- Binary expressions: element-wise functions ---- */
  OPERATOR_ELEMENT_POW_TYPE,
  OPERATOR_ELEMENT_FMAX_TYPE,
  OPERATOR_ELEMENT_FMIN_TYPE,
  OPERATOR_ELEMENT_MAX_TYPE,
  OPERATOR_ELEMENT_MIN_TYPE,

  /* ---- Binary expressions: structural operations ---- */
  OPERATOR_OUTER_PROD_TYPE,
  OPERATOR_MATRIX_DIAG_TYPE,
  OPERATOR_MATRIX_ROW_TYPE,
  OPERATOR_MATRIX_COLUMN_TYPE,
  OPERATOR_REPEAT_TYPE,
  OPERATOR_SHIFT_TYPE,
  OPERATOR_VDIAG_TYPE,

  /* ---- Matrix products; suffix letters presumably give (lhs, rhs) as N = normal, T = transposed ---- */
  OPERATOR_MATRIX_PRODUCT_NN_TYPE,
  OPERATOR_MATRIX_PRODUCT_TN_TYPE,
  OPERATOR_MATRIX_PRODUCT_NT_TYPE,
  OPERATOR_MATRIX_PRODUCT_TT_TYPE,

  /* ---- Pairing ---- */
  OPERATOR_PAIR_TYPE
};
|
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
/**
 * @brief Coarse category of an operand in the array_expression tree, used for fast dispatch.
 *
 * NOTE(review): presumably COMPOSITE_OPERATOR_FAMILY marks an operand that refers to
 * another node of the tree (lhs_rhs_element::node_index), and INFOS_TYPE_FAMILY one that
 * carries repeat_infos — confirm against the fill() definitions.
 */
enum array_expression_node_type_family
{
  INVALID_TYPE_FAMILY = 0,
  COMPOSITE_OPERATOR_FAMILY = 1,
  VALUE_TYPE_FAMILY = 2,
  ARRAY_TYPE_FAMILY = 3,
  INFOS_TYPE_FAMILY = 4
};
|
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
/**
 * @brief Fine-grained kind of a leaf operand in the array_expression tree.
 *
 * Refines array_expression_node_type_family. Values are spelled out explicitly
 * and match the implicit numbering 0..3.
 */
enum array_expression_node_subtype
{
  INVALID_SUBTYPE = 0,
  VALUE_SCALAR_TYPE = 1,
  DENSE_ARRAY_TYPE = 2,
  REPEAT_INFOS_TYPE = 3
};
|
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
struct op_element
|
|
|
|
{
|
|
|
|
op_element();
|
|
|
|
op_element(operation_node_type_family const & _type_family, operation_node_type const & _type);
|
|
|
|
operation_node_type_family type_family;
|
|
|
|
operation_node_type type;
|
|
|
|
};
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
/**
 * @brief One operand (left- or right-hand side) of an array_expression node.
 *
 * type_family / subtype / dtype describe the operand; the anonymous union holds its
 * payload, of which only one member is meaningful at a time.
 * NOTE(review): which union member is active for which family/subtype is decided by
 * the fill() overloads (defined out of line) — confirm there before reading a member.
 * Declaration order is significant (layout), so members must not be reordered.
 */
struct lhs_rhs_element
{
  lhs_rhs_element();
  array_expression_node_type_family type_family;  // coarse operand category
  array_expression_node_subtype subtype;          // fine-grained operand kind
  numeric_type dtype;                             // element type of the operand
  union
  {
    unsigned int node_index;  // presumably the index of another node in the tree (composite operands)
    values_holder vscalar;    // presumably the scalar value (value operands)
    repeat_infos tuple;       // presumably repeat parameters (infos operands)
    // NOTE: this member name hides the class atidlas::array inside lhs_rhs_element's scope.
    array_infos array;        // presumably array metadata (array operands)
  };
  // Kept outside the union: cl::Buffer is a ref-counted handle with a non-trivial
  // constructor/destructor. Presumably only set for array operands — confirm.
  cl::Buffer memory_;
};
|
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
struct invalid_node{};
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
void fill(lhs_rhs_element &x, invalid_node);
|
|
|
|
void fill(lhs_rhs_element & x, unsigned int node_index);
|
|
|
|
void fill(lhs_rhs_element & x, array const & a);
|
|
|
|
void fill(lhs_rhs_element & x, value_scalar const & v);
|
|
|
|
void fill(lhs_rhs_element & x, repeat_infos const & r);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-02-04 22:06:15 -05:00
|
|
|
class array_expression : public array_base
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
public:
|
2015-01-31 22:01:48 -05:00
|
|
|
struct node
|
|
|
|
{
|
|
|
|
lhs_rhs_element lhs;
|
|
|
|
op_element op;
|
|
|
|
lhs_rhs_element rhs;
|
|
|
|
};
|
|
|
|
|
|
|
|
typedef std::vector<node> container_type;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
public:
|
|
|
|
template<class LT, class RT>
|
|
|
|
array_expression(LT const & lhs, RT const & rhs, op_element const & op, cl::Context const & ctx, numeric_type const & dtype, size4 const & shape);
|
|
|
|
template<class RT>
|
|
|
|
array_expression(array_expression const & lhs, RT const & rhs, op_element const & op, numeric_type const & dtype, size4 const & shape);
|
|
|
|
template<class LT>
|
|
|
|
array_expression(LT const & lhs, array_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 const & shape);
|
|
|
|
array_expression(array_expression const & lhs, array_expression const & rhs, op_element const & op, numeric_type const & dtype, size4 const & shape);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
size4 shape() const;
|
|
|
|
array_expression& reshape(int_t size1, int_t size2=1);
|
|
|
|
int_t nshape() const;
|
2015-01-16 07:31:39 -05:00
|
|
|
container_type & tree();
|
|
|
|
container_type const & tree() const;
|
2015-01-12 13:20:53 -05:00
|
|
|
std::size_t root() const;
|
|
|
|
cl::Context const & context() const;
|
|
|
|
numeric_type const & dtype() const;
|
2015-01-16 07:31:39 -05:00
|
|
|
|
2015-01-19 21:29:47 -05:00
|
|
|
array_expression operator-();
|
2015-01-29 15:19:40 -05:00
|
|
|
array_expression operator!();
|
2015-01-12 13:20:53 -05:00
|
|
|
private:
|
2015-01-31 22:01:48 -05:00
|
|
|
container_type tree_;
|
|
|
|
std::size_t root_;
|
|
|
|
cl::Context context_;
|
|
|
|
numeric_type dtype_;
|
2015-01-12 13:20:53 -05:00
|
|
|
size4 shape_;
|
|
|
|
};
|
|
|
|
|
2015-02-05 04:42:57 -05:00
|
|
|
class operation_cache
|
|
|
|
{
|
|
|
|
struct infos
|
|
|
|
{
|
|
|
|
cl::CommandQueue & queue;
|
|
|
|
cl::Kernel kernel;
|
|
|
|
cl::NDRange offset;
|
|
|
|
cl::NDRange global;
|
|
|
|
cl::NDRange local;
|
|
|
|
std::vector<cl::Event>* dependencies;
|
|
|
|
cl::Event* event;
|
|
|
|
};
|
|
|
|
|
|
|
|
public:
|
|
|
|
void push_back(cl::CommandQueue & queue, cl::Kernel const & kernel, cl::NDRange const & offset, cl::NDRange const & global, cl::NDRange const & local, std::vector<cl::Event>* dependencies, cl::Event* event)
|
|
|
|
{ l_.push_back({queue, kernel, offset, global, local, dependencies, event}); }
|
|
|
|
|
|
|
|
void enqueue()
|
|
|
|
{
|
|
|
|
for(infos & i : l_)
|
|
|
|
i.queue.enqueueNDRangeKernel(i.kernel, i.offset, i.global, i.local, i.dependencies, i.event);
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::list<infos> l_;
|
|
|
|
};
|
|
|
|
|
|
|
|
struct execution_options_type
|
|
|
|
{
|
|
|
|
execution_options_type(unsigned int _queue_id = 0, cl::Event* _event = NULL, operation_cache* _cache = NULL, std::vector<cl::Event>* _dependencies = NULL) : queue_id(_queue_id), event(_event), cache(_cache), dependencies(_dependencies){}
|
|
|
|
|
|
|
|
void enqueue_cache(cl::CommandQueue & queue, cl::Kernel const & kernel, cl::NDRange offset, cl::NDRange global, cl::NDRange local) const
|
|
|
|
{
|
|
|
|
queue.enqueueNDRangeKernel(kernel, offset, global, local, dependencies, event);
|
|
|
|
if(cache)
|
|
|
|
cache->push_back(queue, kernel, cl::NullRange, global, local, dependencies, event);
|
|
|
|
}
|
|
|
|
|
|
|
|
unsigned int queue_id;
|
|
|
|
cl::Event* event;
|
|
|
|
operation_cache* cache;
|
|
|
|
std::vector<cl::Event>* dependencies;
|
|
|
|
};
|
|
|
|
|
|
|
|
/**
 * @brief Options forwarded to the kernel dispatcher.
 *
 * `label` defaults to -1. Construction from a plain int is implicit.
 * NOTE(review): presumably a non-negative label forces a specific kernel profile and -1
 * lets the dispatcher choose — confirm against the dispatcher code.
 */
struct dispatcher_options_type
{
  int label;

  dispatcher_options_type(int _label = -1)
    : label(_label)
  {}
};
|
|
|
|
|
|
|
|
/**
 * @brief Options controlling kernel compilation.
 *
 * Defaults: empty program_name, recompile == false.
 * NOTE(review): presumably program_name identifies the built program for reuse, and
 * recompile forces a rebuild even if one already exists — confirm at the use sites.
 */
struct compilation_options_type
{
  std::string program_name;
  bool recompile;

  compilation_options_type(std::string const & _program_name = "", bool _recompile = false)
    : program_name(_program_name),
      recompile(_recompile)
  {}
};
|
|
|
|
|
2015-02-04 22:06:15 -05:00
|
|
|
template<class TYPE>
|
|
|
|
class controller
|
2015-02-03 15:20:33 -05:00
|
|
|
{
|
|
|
|
public:
|
2015-02-05 04:42:57 -05:00
|
|
|
controller(TYPE const & x, execution_options_type const& execution_options = execution_options_type(),
|
|
|
|
dispatcher_options_type const & dispatcher_options = dispatcher_options_type(), compilation_options_type const & compilation_options = compilation_options_type())
|
|
|
|
: x_(x), execution_options_(execution_options), dispatcher_options_(dispatcher_options), compilation_options_(compilation_options){}
|
2015-02-03 15:20:33 -05:00
|
|
|
|
2015-02-04 22:06:15 -05:00
|
|
|
TYPE const & x() const { return x_; }
|
2015-02-05 04:42:57 -05:00
|
|
|
execution_options_type const & execution_options() const { return execution_options_; }
|
|
|
|
dispatcher_options_type const & dispatcher_options() const { return dispatcher_options_; }
|
|
|
|
compilation_options_type const & compilation_options() const { return compilation_options_; }
|
2015-02-03 15:20:33 -05:00
|
|
|
private:
|
2015-02-04 22:06:15 -05:00
|
|
|
TYPE const & x_;
|
2015-02-05 04:42:57 -05:00
|
|
|
execution_options_type execution_options_;
|
|
|
|
dispatcher_options_type dispatcher_options_;
|
|
|
|
compilation_options_type compilation_options_;
|
2015-02-03 15:20:33 -05:00
|
|
|
};
|
|
|
|
|
2015-02-05 04:42:57 -05:00
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
class expressions_tuple
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
private:
|
2015-02-04 22:06:15 -05:00
|
|
|
std::shared_ptr<array_expression> create(array_expression const & s);
|
2015-01-12 13:20:53 -05:00
|
|
|
public:
|
2015-02-04 22:06:15 -05:00
|
|
|
typedef std::list<std::shared_ptr<array_expression> > data_type;
|
2015-01-12 13:20:53 -05:00
|
|
|
enum order_type { SEQUENTIAL, INDEPENDENT };
|
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
expressions_tuple(array_expression const & s0);
|
|
|
|
expressions_tuple(order_type order, array_expression const & s0, array_expression const & s1);
|
|
|
|
expressions_tuple(data_type const & data, order_type order);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
data_type const & data() const;
|
|
|
|
cl::Context const & context() const;
|
|
|
|
order_type order() const;
|
|
|
|
private:
|
|
|
|
data_type data_;
|
|
|
|
order_type order_;
|
|
|
|
};
|
|
|
|
|
2015-01-31 22:01:48 -05:00
|
|
|
/*
 * lhs_most(): locate the left-most node of an expression tree.
 * NOTE(review): presumably walks the lhs operands that refer to other nodes, starting
 * from `init` (first overload) or from the node at index `root` (second overload), and
 * returns the last node reached — confirm against the out-of-line definitions.
 *
 * `std::size_t` is used (instead of unqualified `size_t`, the same type) for consistency
 * with array_expression::root() and so the declaration does not rely on a global-namespace
 * `size_t` being visible. The parameter is named `tree` rather than `array` to avoid
 * shadowing the class atidlas::array.
 */
array_expression::node const & lhs_most(array_expression::container_type const & tree, array_expression::node const & init);
array_expression::node const & lhs_most(array_expression::container_type const & tree, std::size_t root);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|