Cleaning: Largely renamed templates to BLAS-like names

This commit is contained in:
Philippe Tillet
2015-07-11 09:36:01 -04:00
parent 281fa9c7a6
commit cfa6ea812d
40 changed files with 606 additions and 572 deletions

View File

@@ -219,15 +219,15 @@ array_expression norm(array_expression const &, unsigned int order = 2);
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2); array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
#define ISAAC_DECLARE_REDUCTION(OPNAME) \ #define ISAAC_DECLARE_DOT(OPNAME) \
array_expression OPNAME(array const & M, int_t axis = -1);\ array_expression OPNAME(array const & M, int_t axis = -1);\
array_expression OPNAME(array_expression const & M, int_t axis = -1); array_expression OPNAME(array_expression const & M, int_t axis = -1);
ISAAC_DECLARE_REDUCTION(sum) ISAAC_DECLARE_DOT(sum)
ISAAC_DECLARE_REDUCTION(argmax) ISAAC_DECLARE_DOT(argmax)
ISAAC_DECLARE_REDUCTION(max) ISAAC_DECLARE_DOT(max)
ISAAC_DECLARE_REDUCTION(min) ISAAC_DECLARE_DOT(min)
ISAAC_DECLARE_REDUCTION(argmin) ISAAC_DECLARE_DOT(argmin)
array_expression eye(std::size_t, std::size_t, isaac::numeric_type, driver::Context context = driver::queues.default_context()); array_expression eye(std::size_t, std::size_t, isaac::numeric_type, driver::Context context = driver::queues.default_context());
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, driver::Context context = driver::queues.default_context()); array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, driver::Context context = driver::queues.default_context());

View File

@@ -84,46 +84,46 @@ protected:
* *
* Maps prod(matrix_expression, matrix_expression) * Maps prod(matrix_expression, matrix_expression)
*/ */
class mapped_mproduct : public mapped_object, public binary_leaf class mapped_gemm : public mapped_object, public binary_leaf
{ {
public: public:
mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info); mapped_gemm(std::string const & scalartype, unsigned int id, node_info info);
}; };
/** @brief Reduction /** @brief Reduction
* *
* Base class for mapping a reduction * Base class for mapping a dot
*/ */
class mapped_reduction : public mapped_object, public binary_leaf class mapped_dot : public mapped_object, public binary_leaf
{ {
public: public:
mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key); mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
int_t root_idx() const; int_t root_idx() const;
isaac::array_expression const & array_expression() const; isaac::array_expression const & array_expression() const;
array_expression::node root_node() const; array_expression::node root_node() const;
bool is_index_reduction() const; bool is_index_dot() const;
op_element root_op() const; op_element root_op() const;
}; };
/** @brief Scalar reduction /** @brief Scalar dot
* *
* Maps a scalar reduction (max, min, argmax, inner_prod, etc..) * Maps a scalar dot (max, min, argmax, inner_prod, etc..)
*/ */
class mapped_scalar_reduction : public mapped_reduction class mapped_scalar_dot : public mapped_dot
{ {
public: public:
mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info); mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info);
}; };
/** @brief Vector reduction /** @brief Vector dot
* *
* Maps a row-wise reduction (max, min, argmax, matrix-vector product, etc..) * Maps a row-wise dot (max, min, argmax, matrix-vector product, etc..)
*/ */
class mapped_mreduction : public mapped_reduction class mapped_gemv : public mapped_dot
{ {
public: public:
mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info); mapped_gemv(std::string const & scalartype, unsigned int id, node_info info);
}; };
/** @brief Host scalar /** @brief Host scalar

View File

@@ -13,8 +13,8 @@ namespace detail
{ {
bool is_node_leaf(op_element const & op); bool is_node_leaf(op_element const & op);
bool is_scalar_reduction(array_expression::node const & node); bool is_scalar_dot(array_expression::node const & node);
bool is_vector_reduction(array_expression::node const & node); bool is_vector_dot(array_expression::node const & node);
bool is_assignment(op_element const & op); bool is_assignment(op_element const & op);
bool is_elementwise_operator(op_element const & op); bool is_elementwise_operator(op_element const & op);
bool is_elementwise_function(op_element const & op); bool is_elementwise_function(op_element const & op);

View File

@@ -5,27 +5,30 @@
namespace isaac namespace isaac
{ {
namespace templates
{
class vaxpy_parameters : public base::parameters_type class axpy_parameters : public base::parameters_type
{ {
public: public:
vaxpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy); axpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy);
unsigned int num_groups; unsigned int num_groups;
fetching_policy_type fetching_policy; fetching_policy_type fetching_policy;
}; };
class vaxpy : public base_impl<vaxpy, vaxpy_parameters> class axpy : public base_impl<axpy, axpy_parameters>
{ {
private: private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const; virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const; std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public: public:
vaxpy(vaxpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE); axpy(axpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
vaxpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE); axpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const; std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &); void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
}; };
}
} }
#endif #endif

View File

@@ -14,6 +14,9 @@
namespace isaac namespace isaac
{ {
namespace templates
{
enum fetching_policy_type enum fetching_policy_type
{ {
FETCH_FROM_LOCAL, FETCH_FROM_LOCAL,
@@ -147,8 +150,8 @@ protected:
} }
} }
static void compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op); static void compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op);
static void compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op); static void compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op);
static void process_all(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings); static void process_all(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings);
static void process_all_at(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings, size_t root_idx, leaf_t leaf); static void process_all_at(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings, size_t root_idx, leaf_t leaf);
static std::string neutral_element(op_element const & op, driver::backend_type backend, std::string const & datatype); static std::string neutral_element(op_element const & op, driver::backend_type backend, std::string const & datatype);
@@ -159,8 +162,8 @@ protected:
static bool is_strided(array_expression::node const & node); static bool is_strided(array_expression::node const & node);
static int_t vector_size(array_expression::node const & node); static int_t vector_size(array_expression::node const & node);
static std::pair<int_t, int_t> matrix_size(array_expression::node const & node); static std::pair<int_t, int_t> matrix_size(array_expression::node const & node);
static bool is_reduction(array_expression::node const & node); static bool is_dot(array_expression::node const & node);
static bool is_index_reduction(op_element const & op); static bool is_index_dot(op_element const & op);
static std::string access_vector_type(std::string const & v, int i); static std::string access_vector_type(std::string const & v, int i);
tools::shared_ptr<symbolic_binder> make_binder(); tools::shared_ptr<symbolic_binder> make_binder();
@@ -204,6 +207,7 @@ protected:
binding_policy_t binding_policy_; binding_policy_t binding_policy_;
}; };
}
} }
#endif #endif

View File

@@ -1,32 +1,34 @@
#ifndef ISAAC_BACKEND_TEMPLATES_REDUCTION_H #ifndef ISAAC_BACKEND_TEMPLATES_DOT_H
#define ISAAC_BACKEND_TEMPLATES_REDUCTION_H #define ISAAC_BACKEND_TEMPLATES_DOT_H
#include "isaac/backend/templates/base.h" #include "isaac/backend/templates/base.h"
namespace isaac namespace isaac
{ {
namespace templates
struct reduction_parameters : public base::parameters_type
{ {
reduction_parameters(unsigned int _simd_width,
struct dot_parameters : public base::parameters_type
{
dot_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups, unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy); fetching_policy_type _fetching_policy);
unsigned int num_groups; unsigned int num_groups;
fetching_policy_type fetching_policy; fetching_policy_type fetching_policy;
}; };
class reduction : public base_impl<reduction, reduction_parameters> class dot : public base_impl<dot, dot_parameters>
{ {
private: private:
unsigned int lmem_usage(expressions_tuple const & expressions) const; unsigned int lmem_usage(expressions_tuple const & expressions) const;
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const; int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs, inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const; std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const; std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public: public:
reduction(reduction::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE); dot(dot::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
reduction(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE); dot(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const; std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &); void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private: private:
@@ -34,6 +36,7 @@ private:
std::vector< driver::Buffer > tmpidx_; std::vector< driver::Buffer > tmpidx_;
}; };
}
} }
#endif #endif

View File

@@ -7,12 +7,13 @@
namespace isaac namespace isaac
{ {
namespace templates
{
class model; class model;
struct mproduct_parameters : public base::parameters_type struct gemm_parameters : public base::parameters_type
{ {
mproduct_parameters(unsigned int simd_width gemm_parameters(unsigned int simd_width
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D , int_t local_size_0, int_t KL, int_t local_size_1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy , fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
@@ -38,7 +39,7 @@ struct mproduct_parameters : public base::parameters_type
bool unroll_outer; bool unroll_outer;
}; };
class mproduct : public base_impl<mproduct, mproduct_parameters> class gemm : public base_impl<gemm, gemm_parameters>
{ {
private: private:
unsigned int lmem_usage(expressions_tuple const & expressions) const; unsigned int lmem_usage(expressions_tuple const & expressions) const;
@@ -50,7 +51,7 @@ private:
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap); array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const; std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
public: public:
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans); gemm(gemm::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const; std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback, void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C); lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
@@ -62,41 +63,41 @@ private:
bool check_bounds_; bool check_bounds_;
}; };
class mproduct_nn : public mproduct class gemm_nn : public gemm
{ {
public: public:
mproduct_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch , int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false); , int_t lfetch0, int_t lfetch1, bool check_bound = false);
}; };
class mproduct_tn : public mproduct class gemm_tn : public gemm
{ {
public: public:
mproduct_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch , int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false); , int_t lfetch0, int_t lfetch1, bool check_bound = false);
}; };
class mproduct_nt : public mproduct class gemm_nt : public gemm
{ {
public: public:
mproduct_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch , int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false); , int_t lfetch0, int_t lfetch1, bool check_bound = false);
}; };
class mproduct_tt : public mproduct class gemm_tt : public gemm
{ {
public: public:
mproduct_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch , int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false); , int_t lfetch0, int_t lfetch1, bool check_bound = false);
}; };
}
} }
#endif #endif

View File

@@ -0,0 +1,61 @@
#ifndef ISAAC_BACKEND_TEMPLATES_MDOT_H
#define ISAAC_BACKEND_TEMPLATES_MDOT_H
#include <vector>
#include "isaac/symbolic/expression.h"
#include "isaac/backend/templates/base.h"
namespace isaac
{
namespace templates
{
struct gemv_parameters : public base::parameters_type
{
gemv_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
unsigned int num_groups_0;
unsigned int num_groups_1;
fetching_policy_type fetch_policy;
};
class gemv : public base_impl<gemv, gemv_parameters>
{
protected:
enum dot_type
{
REDUCE_ROWS,
REDUCE_COLUMNS
};
gemv(gemv::parameters_type const & , dot_type, binding_policy_t);
private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
unsigned int lmem_usage() const;
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
public:
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private:
dot_type dot_type_;
};
class gemv_n : public gemv
{
public:
gemv_n(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
class gemv_t : public gemv
{
public:
gemv_t(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
}
}
#endif

View File

@@ -6,29 +6,32 @@
namespace isaac namespace isaac
{ {
namespace templates
{
class maxpy_parameters : public base::parameters_type class ger_parameters : public base::parameters_type
{ {
public: public:
maxpy_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy); ger_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy);
unsigned int num_groups_0; unsigned int num_groups_0;
unsigned int num_groups_1; unsigned int num_groups_1;
fetching_policy_type fetching_policy; fetching_policy_type fetching_policy;
}; };
class maxpy : public base_impl<maxpy, maxpy_parameters> class ger : public base_impl<ger, ger_parameters>
{ {
private: private:
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const; int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const; std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public: public:
maxpy(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE); ger(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE); ger(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const; std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &); void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
}; };
}
} }
#endif #endif

View File

@@ -1,59 +0,0 @@
#ifndef ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
#define ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
#include <vector>
#include "isaac/symbolic/expression.h"
#include "isaac/backend/templates/base.h"
namespace isaac
{
struct mreduction_parameters : public base::parameters_type
{
mreduction_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
unsigned int num_groups_0;
unsigned int num_groups_1;
fetching_policy_type fetch_policy;
};
class mreduction : public base_impl<mreduction, mreduction_parameters>
{
protected:
enum reduction_type
{
REDUCE_ROWS,
REDUCE_COLUMNS
};
mreduction(mreduction::parameters_type const & , reduction_type, binding_policy_t);
private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
unsigned int lmem_usage() const;
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
public:
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private:
reduction_type reduction_type_;
};
class mreduction_rows : public mreduction
{
public:
mreduction_rows(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
class mreduction_cols : public mreduction
{
public:
mreduction_cols(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
}
#endif

View File

@@ -14,7 +14,7 @@ namespace isaac
class model class model
{ {
typedef tools::shared_ptr<base> template_pointer; typedef tools::shared_ptr<templates::base> template_pointer;
typedef std::vector< template_pointer > templates_container; typedef std::vector< template_pointer > templates_container;
private: private:
@@ -23,8 +23,8 @@ namespace isaac
driver::Program& init(controller<expressions_tuple> const &); driver::Program& init(controller<expressions_tuple> const &);
public: public:
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<base> > const &, driver::CommandQueue const &); model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
model(expression_type, numeric_type, base const &, driver::CommandQueue const &); model(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
void execute(controller<expressions_tuple> const &); void execute(controller<expressions_tuple> const &);
templates_container const & templates() const; templates_container const & templates() const;
@@ -46,7 +46,7 @@ namespace isaac
model_map_t init_models(driver::CommandQueue const & queue); model_map_t init_models(driver::CommandQueue const & queue);
model_map_t& models(driver::CommandQueue & queue); model_map_t& models(driver::CommandQueue & queue);
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks; extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks;
extern std::map<driver::CommandQueue, model_map_t> models_; extern std::map<driver::CommandQueue, model_map_t> models_;
} }

View File

@@ -31,14 +31,14 @@ enum operation_node_type_family
// BLAS1-type // BLAS1-type
OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_UNARY_TYPE_FAMILY,
OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_BINARY_TYPE_FAMILY,
OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OPERATOR_VECTOR_DOT_TYPE_FAMILY,
// BLAS2-type // BLAS2-type
OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OPERATOR_ROWS_DOT_TYPE_FAMILY,
OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OPERATOR_COLUMNS_DOT_TYPE_FAMILY,
// BLAS3-type // BLAS3-type
OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY OPERATOR_GEMM_TYPE_FAMILY
}; };
/** @brief Enumeration for identifying the possible operations */ /** @brief Enumeration for identifying the possible operations */
@@ -119,10 +119,10 @@ enum operation_node_type
OPERATOR_SHIFT_TYPE, OPERATOR_SHIFT_TYPE,
OPERATOR_VDIAG_TYPE, OPERATOR_VDIAG_TYPE,
OPERATOR_MATRIX_PRODUCT_NN_TYPE, OPERATOR_GEMM_NN_TYPE,
OPERATOR_MATRIX_PRODUCT_TN_TYPE, OPERATOR_GEMM_TN_TYPE,
OPERATOR_MATRIX_PRODUCT_NT_TYPE, OPERATOR_GEMM_NT_TYPE,
OPERATOR_MATRIX_PRODUCT_TT_TYPE, OPERATOR_GEMM_TT_TYPE,
OPERATOR_PAIR_TYPE OPERATOR_PAIR_TYPE
}; };

View File

@@ -128,16 +128,15 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
enum expression_type enum expression_type
{ {
INVALID_EXPRESSION_TYPE, INVALID_EXPRESSION_TYPE,
SCALAR_AXPY_TYPE, AXPY_TYPE,
VECTOR_AXPY_TYPE, GER_TYPE,
MATRIX_AXPY_TYPE, DOT_TYPE,
REDUCTION_TYPE, GEMV_N_TYPE,
ROW_WISE_REDUCTION_TYPE, GEMV_T_TYPE,
COL_WISE_REDUCTION_TYPE, GEMM_NN_TYPE,
MATRIX_PRODUCT_NN_TYPE, GEMM_TN_TYPE,
MATRIX_PRODUCT_TN_TYPE, GEMM_NT_TYPE,
MATRIX_PRODUCT_NT_TYPE, GEMM_TT_TYPE
MATRIX_PRODUCT_TT_TYPE
}; };
struct slice struct slice

View File

@@ -596,17 +596,17 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
///*--- Reductions ---*/ ///*--- Reductions ---*/
////--------------------------------------- ////---------------------------------------
#define DEFINE_REDUCTION(OP, OPNAME)\ #define DEFINE_DOT(OP, OPNAME)\
array_expression OPNAME(array const & x, int_t axis)\ array_expression OPNAME(array const & x, int_t axis)\
{\ {\
if(axis < -1 || axis > x.nshape())\ if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\ throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\ else if(axis==-1)\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\ return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
else if(axis==0)\ else if(axis==0)\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\ return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
else\ else\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\ return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
}\ }\
\ \
array_expression OPNAME(array_expression const & x, int_t axis)\ array_expression OPNAME(array_expression const & x, int_t axis)\
@@ -614,20 +614,20 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
if(axis < -1 || axis > x.nshape())\ if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\ throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\ if(axis==-1)\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\ return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
else if(axis==0)\ else if(axis==0)\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\ return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
else\ else\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\ return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
} }
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum) DEFINE_DOT(OPERATOR_ADD_TYPE, sum)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax) DEFINE_DOT(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max) DEFINE_DOT(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min) DEFINE_DOT(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin) DEFINE_DOT(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
#undef DEFINE_REDUCTION #undef DEFINE_DOT
namespace detail namespace detail
{ {
@@ -635,21 +635,21 @@ namespace detail
array_expression matmatprod(array const & A, array const & B) array_expression matmatprod(array const & A, array const & B)
{ {
size4 shape(A.shape()[0], B.shape()[1]); size4 shape(A.shape()[0], B.shape()[1]);
return array_expression(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, OPERATOR_MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape); return array_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
} }
array_expression matmatprod(array_expression const & A, array const & B) array_expression matmatprod(array_expression const & A, array const & B)
{ {
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE; operation_node_type type = OPERATOR_GEMM_NN_TYPE;
size4 shape(A.shape()[0], B.shape()[1]); size4 shape(A.shape()[0], B.shape()[1]);
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]); array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans){ if(A_trans){
type = OPERATOR_MATRIX_PRODUCT_TN_TYPE; type = OPERATOR_GEMM_TN_TYPE;
} }
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape); array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]); array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
return res; return res;
@@ -657,16 +657,16 @@ namespace detail
array_expression matmatprod(array const & A, array_expression const & B) array_expression matmatprod(array const & A, array_expression const & B)
{ {
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE; operation_node_type type = OPERATOR_GEMM_NN_TYPE;
size4 shape(A.shape()[0], B.shape()[1]); size4 shape(A.shape()[0], B.shape()[1]);
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]); array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE; bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
if(B_trans){ if(B_trans){
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE; type = OPERATOR_GEMM_NT_TYPE;
} }
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape); array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]); array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;
return res; return res;
@@ -674,7 +674,7 @@ namespace detail
array_expression matmatprod(array_expression const & A, array_expression const & B) array_expression matmatprod(array_expression const & A, array_expression const & B)
{ {
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE; operation_node_type type = OPERATOR_GEMM_NN_TYPE;
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]); array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]); array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
size4 shape(A.shape()[0], B.shape()[1]); size4 shape(A.shape()[0], B.shape()[1]);
@@ -682,12 +682,12 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE; bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE; bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE; if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE; else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE; else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE; else type = OPERATOR_GEMM_NN_TYPE;
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape); array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]); array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs; if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs; if(B_trans) res_root.rhs = B_root.lhs;

View File

@@ -102,23 +102,23 @@ std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, s
} }
mapped_mproduct::mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "mproduct"), binary_leaf(info) { } mapped_gemm::mapped_gemm(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "gemm"), binary_leaf(info) { }
// //
mapped_reduction::mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) : mapped_dot::mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
mapped_object(scalartype, id, type_key), binary_leaf(info) mapped_object(scalartype, id, type_key), binary_leaf(info)
{ } { }
int_t mapped_reduction::root_idx() const int_t mapped_dot::root_idx() const
{ return info_.root_idx; } { return info_.root_idx; }
isaac::array_expression const & mapped_reduction::array_expression() const isaac::array_expression const & mapped_dot::array_expression() const
{ return *info_.array_expression; } { return *info_.array_expression; }
array_expression::node mapped_reduction::root_node() const array_expression::node mapped_dot::root_node() const
{ return array_expression().tree()[root_idx()]; } { return array_expression().tree()[root_idx()]; }
bool mapped_reduction::is_index_reduction() const bool mapped_dot::is_index_dot() const
{ {
op_element const & op = root_op(); op_element const & op = root_op();
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
@@ -127,17 +127,17 @@ bool mapped_reduction::is_index_reduction() const
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE; || op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
} }
op_element mapped_reduction::root_op() const op_element mapped_dot::root_op() const
{ {
return info_.array_expression->tree()[info_.root_idx].op; return info_.array_expression->tree()[info_.root_idx].op;
} }
// //
mapped_scalar_reduction::mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ } mapped_scalar_dot::mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "scalar_dot"){ }
// //
mapped_mreduction::mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "mreduction") { } mapped_gemv::mapped_gemv(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "gemv") { }
// //
void mapped_host_scalar::preprocess(std::string & str) const void mapped_host_scalar::preprocess(std::string & str) const

View File

@@ -10,15 +10,15 @@ namespace detail
bool is_scalar_reduction(array_expression::node const & node) bool is_scalar_dot(array_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY; return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
} }
bool is_vector_reduction(array_expression::node const & node) bool is_vector_dot(array_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY; || node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
} }
bool is_assignment(op_element const & op) bool is_assignment(op_element const & op)
@@ -75,10 +75,10 @@ namespace detail
|| op.type==OPERATOR_MATRIX_ROW_TYPE || op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE || op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE || op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY || op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY || op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY || op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY || op.type_family==OPERATOR_GEMM_TYPE_FAMILY
; ;
} }
@@ -214,10 +214,10 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ELEMENT_MIN_TYPE : return "min"; case OPERATOR_ELEMENT_MIN_TYPE : return "min";
//Binary //Binary
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN"; case OPERATOR_GEMM_NN_TYPE : return "prodNN";
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN"; case OPERATOR_GEMM_TN_TYPE : return "prodTN";
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT"; case OPERATOR_GEMM_NT_TYPE : return "prodNT";
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT"; case OPERATOR_GEMM_TT_TYPE : return "prodTT";
case OPERATOR_VDIAG_TYPE : return "vdiag"; case OPERATOR_VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag"; case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row"; case OPERATOR_MATRIX_ROW_TYPE : return "row";

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/vaxpy.h" #include "isaac/backend/templates/axpy.h"
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
#include "isaac/driver/backend.h" #include "isaac/driver/backend.h"
#include "isaac/tools/make_map.hpp" #include "isaac/tools/make_map.hpp"
@@ -8,23 +8,24 @@
namespace isaac namespace isaac
{ {
namespace templates
{
axpy_parameters::axpy_parameters(unsigned int _simd_width,
vaxpy_parameters::vaxpy_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups, unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) : fetching_policy_type _fetching_policy) :
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy) base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ } { }
int vaxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const int axpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{ {
if (p_.fetching_policy==FETCH_FROM_LOCAL) if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE; return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const std::string axpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{ {
driver::backend_type backend = device.backend(); driver::backend_type backend = device.backend();
std::string _size_t = size_type(device); std::string _size_t = size_type(device);
@@ -90,25 +91,24 @@ std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str(); return stream.str();
} }
vaxpy::vaxpy(vaxpy_parameters const & parameters, axpy::axpy(axpy_parameters const & parameters,
binding_policy_t binding_policy) : binding_policy_t binding_policy) :
base_impl<vaxpy, vaxpy_parameters>(parameters, binding_policy) base_impl<axpy, axpy_parameters>(parameters, binding_policy)
{} {}
vaxpy::vaxpy(unsigned int simd, unsigned int ls, unsigned int ng, axpy::axpy(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind): fetching_policy_type fetch, binding_policy_t bind):
base_impl<vaxpy, vaxpy_parameters>(vaxpy_parameters(simd,ls,ng,fetch), bind) base_impl<axpy, axpy_parameters>(axpy_parameters(simd,ls,ng,fetch), bind)
{} {}
std::vector<int_t> vaxpy::input_sizes(expressions_tuple const & expressions) const std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) const
{ {
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape(); size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
int_t size = static_cast<array_expression const *>(expressions.data().front().get())->shape()[0];
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]); return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
} }
void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller) void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{ {
expressions_tuple const & expressions = controller.x(); expressions_tuple const & expressions = controller.x();
//Size //Size
@@ -135,3 +135,4 @@ void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
} }
}

View File

@@ -2,11 +2,11 @@
#include "isaac/array.h" #include "isaac/array.h"
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
#include "isaac/backend/templates/vaxpy.h" #include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/reduction.h" #include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/maxpy.h" #include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/mreduction.h" #include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/mproduct.h" #include "isaac/backend/templates/gemm.h"
#include "isaac/backend/templates/base.h" #include "isaac/backend/templates/base.h"
#include "isaac/backend/parse.h" #include "isaac/backend/parse.h"
#include "isaac/exception/operation_not_supported.h" #include "isaac/exception/operation_not_supported.h"
@@ -17,6 +17,8 @@
namespace isaac namespace isaac
{ {
namespace templates
{
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels) base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
{ } { }
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE) else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
else if (detail::is_scalar_reduction(root_node)) else if (detail::is_scalar_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
else if (detail::is_vector_reduction(root_node)) else if (detail::is_vector_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY) else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE) else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_))); mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE) else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
set_arguments(root_node.rhs); set_arguments(root_node.rhs);
} }
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op) void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
{ {
if (detail::is_elementwise_function(op)) if (detail::is_elementwise_function(op))
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl; os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl; os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
} }
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op) void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
{ {
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl; // os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl; os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
case OPERATOR_ELEMENT_MIN_TYPE : return INF; case OPERATOR_ELEMENT_MIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF; case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known"); default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
} }
} }
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]); return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
} }
bool base::is_reduction(array_expression::node const & node) bool base::is_dot(array_expression::node const & node)
{ {
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY || node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY; || node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
} }
bool base::is_index_reduction(op_element const & op) bool base::is_index_dot(op_element const & op)
{ {
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE || op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
return is_invalid_impl(device, expressions); return is_invalid_impl(device, expressions);
} }
template class base_impl<vaxpy, vaxpy_parameters>; template class base_impl<axpy, axpy_parameters>;
template class base_impl<reduction, reduction_parameters>; template class base_impl<dot, dot_parameters>;
template class base_impl<maxpy, maxpy_parameters>; template class base_impl<ger, ger_parameters>;
template class base_impl<mreduction, mreduction_parameters>; template class base_impl<gemv, gemv_parameters>;
template class base_impl<mproduct, mproduct_parameters>; template class base_impl<gemm, gemm_parameters>;
} }
}

View File

@@ -1,5 +1,5 @@
#include <iostream> #include <iostream>
#include "isaac/backend/templates/reduction.h" #include "isaac/backend/templates/dot.h"
#include <CL/cl.hpp> #include <CL/cl.hpp>
#include "isaac/tools/to_string.hpp" #include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp" #include "isaac/tools/make_map.hpp"
@@ -7,13 +7,14 @@
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
namespace isaac namespace isaac
{ {
namespace templates
reduction_parameters::reduction_parameters(unsigned int _simd_width, {
dot_parameters::dot_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups, unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy) fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ } { }
unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const unsigned int dot::lmem_usage(expressions_tuple const & expressions) const
{ {
unsigned int res = 0; unsigned int res = 0;
for(const auto & elem : expressions.data()) for(const auto & elem : expressions.data())
@@ -24,14 +25,14 @@ unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
return res; return res;
} }
int reduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const int dot::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{ {
if (p_.fetching_policy==FETCH_FROM_LOCAL) if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE; return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs, inline void dot::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
{ {
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
@@ -44,26 +45,26 @@ inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream,
stream.inc_tab(); stream.inc_tab();
for (auto & expr : exprs) for (auto & expr : exprs)
if (expr->is_index_reduction()) if (expr->is_index_dot())
compute_index_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]") compute_index_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"), , expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
expr->root_op()); expr->root_op());
else else
compute_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op()); compute_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
} }
std::string reduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const std::string dot::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{ {
kernel_generation_stream stream; kernel_generation_stream stream;
std::vector<mapped_scalar_reduction*> exprs; std::vector<mapped_scalar_dot*> exprs;
for (const auto & mapping : mappings) for (const auto & mapping : mappings)
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit) for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
if (mapped_scalar_reduction * p = dynamic_cast<mapped_scalar_reduction*>(iit->second.get())) if (mapped_scalar_dot * p = dynamic_cast<mapped_scalar_dot*>(iit->second.get()))
exprs.push_back(p); exprs.push_back(p);
std::size_t N = exprs.size(); std::size_t N = exprs.size();
driver::backend_type backend = device.backend(); driver::backend_type backend = device.backend();
@@ -73,7 +74,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k) for (unsigned int k = 0; k < N; ++k)
{ {
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype); std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
if (exprs[k]->is_index_reduction()) if (exprs[k]->is_index_dot())
{ {
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, "); arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, "); arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
@@ -112,7 +113,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k) for (unsigned int k = 0; k < N; ++k)
{ {
if (exprs[k]->is_index_reduction()) if (exprs[k]->is_index_dot())
{ {
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl; stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl; stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
@@ -156,11 +157,11 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
accessors["matrix_diag"] = str[a]; accessors["matrix_diag"] = str[a];
accessors["array0"] = "#namereg"; accessors["array0"] = "#namereg";
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors); std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
if (elem->is_index_reduction()) if (elem->is_index_dot())
compute_index_reduction(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+" compute_index_dot(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op()); + tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
else else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op()); compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
} }
} }
}); });
@@ -168,7 +169,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
//Fills local memory //Fills local memory
for (unsigned int k = 0; k < N; ++k) for (unsigned int k = 0; k < N; ++k)
{ {
if (exprs[k]->is_index_reduction()) if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl; stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl; stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
} }
@@ -182,7 +183,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream.inc_tab(); stream.inc_tab();
for (unsigned int k = 0; k < N; ++k) for (unsigned int k = 0; k < N; ++k)
{ {
if (exprs[k]->is_index_reduction()) if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl; stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl; stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
} }
@@ -205,9 +206,9 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl; stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl; stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
for (mapped_scalar_reduction* e: exprs) for (mapped_scalar_dot* e: exprs)
{ {
if (e->is_index_reduction()) if (e->is_index_dot())
{ {
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];"); stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
stream << e->process("unsigned int #name_acc = 0;") << std::endl; stream << e->process("unsigned int #name_acc = 0;") << std::endl;
@@ -224,18 +225,18 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl; stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (mapped_scalar_reduction* e: exprs) for (mapped_scalar_dot* e: exprs)
if (e->is_index_reduction()) if (e->is_index_dot())
compute_index_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op()); compute_index_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
else else
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op()); compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
for (unsigned int k = 0; k < N; ++k) for (unsigned int k = 0; k < N; ++k)
{ {
if (exprs[k]->is_index_reduction()) if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl; stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl; stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
} }
@@ -248,7 +249,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
std::map<std::string, std::string> accessors; std::map<std::string, std::string> accessors;
accessors["scalar_reduction"] = "#name_buf[0]"; accessors["scalar_dot"] = "#name_buf[0]";
accessors["array0"] = "#pointer[#start]"; accessors["array0"] = "#pointer[#start]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings); evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
stream.dec_tab(); stream.dec_tab();
@@ -260,23 +261,23 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
return stream.str(); return stream.str();
} }
reduction::reduction(reduction::parameters_type const & parameters, dot::dot(dot::parameters_type const & parameters,
binding_policy_t binding) : base_impl<reduction, reduction_parameters>(parameters, binding) binding_policy_t binding) : base_impl<dot, dot_parameters>(parameters, binding)
{ } { }
reduction::reduction(unsigned int simd, unsigned int ls, unsigned int ng, dot::dot(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind): fetching_policy_type fetch, binding_policy_t bind):
base_impl<reduction, reduction_parameters>(reduction_parameters(simd,ls,ng,fetch), bind) base_impl<dot, dot_parameters>(dot_parameters(simd,ls,ng,fetch), bind)
{} {}
std::vector<int_t> reduction::input_sizes(expressions_tuple const & expressions) const std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
{ {
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *(expressions.data().front()), false); std::vector<size_t> dots_idx = filter_nodes(&is_dot, *(expressions.data().front()), false);
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), reductions_idx[0])); int_t N = vector_size(lhs_most(expressions.data().front()->tree(), dots_idx[0]));
return tools::make_vector<int_t>() << N; return tools::make_vector<int_t>() << N;
} }
void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller) void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{ {
expressions_tuple const & expressions = controller.x(); expressions_tuple const & expressions = controller.x();
@@ -290,12 +291,12 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
return; return;
} }
std::vector<array_expression::node const *> reductions; std::vector<array_expression::node const *> dots;
for (const auto & elem : expressions.data()) for (const auto & elem : expressions.data())
{ {
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *elem, false); std::vector<size_t> dots_idx = filter_nodes(&is_dot, *elem, false);
for (auto & reductions_idx_itt : reductions_idx) for (auto & dots_idx_itt : dots_idx)
reductions.push_back(&(elem)->tree()[reductions_idx_itt]); dots.push_back(&(elem)->tree()[dots_idx_itt]);
} }
//Kernel //Kernel
@@ -321,9 +322,9 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
//Temporary buffers //Temporary buffers
unsigned int i = 0; unsigned int i = 0;
unsigned int j = 0; unsigned int j = 0;
for (std::vector<array_expression::node const *>::const_iterator it = reductions.begin(); it != reductions.end(); ++it) for (std::vector<array_expression::node const *>::const_iterator it = dots.begin(); it != dots.end(); ++it)
{ {
if (is_index_reduction((*it)->op)) if (is_index_dot((*it)->op))
{ {
if (tmpidx_.size() <= j) if (tmpidx_.size() <= j)
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4)); tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
@@ -343,3 +344,4 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
} }
} }
}

View File

@@ -1,5 +1,5 @@
#include "isaac/array.h" #include "isaac/array.h"
#include "isaac/backend/templates/mproduct.h" #include "isaac/backend/templates/gemm.h"
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
#include "isaac/model/model.h" #include "isaac/model/model.h"
#include "isaac/symbolic/preset.h" #include "isaac/symbolic/preset.h"
@@ -10,8 +10,10 @@
namespace isaac namespace isaac
{ {
namespace templates
{
mproduct_parameters::mproduct_parameters(unsigned int simd_width gemm_parameters::gemm_parameters(unsigned int simd_width
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D , int_t local_size_0, int_t KL, int_t local_size_1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy , fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
@@ -21,7 +23,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
mL(ms*local_size_0), nL(ns*local_size_1){} mL(ms*local_size_0), nL(ns*local_size_1){}
unsigned int mproduct::lmem_usage(expressions_tuple const & expressions) const unsigned int gemm::lmem_usage(expressions_tuple const & expressions) const
{ {
isaac::array_expression const & array_expression = (*expressions.data().front()); isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype; numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -32,7 +34,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t); return N*size_of(numeric_t);
} }
unsigned int mproduct::registers_usage(expressions_tuple const & expressions) const unsigned int gemm::registers_usage(expressions_tuple const & expressions) const
{ {
isaac::array_expression const & array_expression = (*expressions.data().front()); isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype; numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -41,7 +43,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t); return N*size_of(numeric_t);
} }
int mproduct::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
{ {
std::vector<int_t> MNK = input_sizes(expressions); std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1]; int_t M = MNK[0]; int_t N = MNK[1];
@@ -95,7 +97,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
std::string mproduct::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const std::string gemm::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
{ {
using std::string; using std::string;
using tools::to_string; using tools::to_string;
@@ -437,7 +439,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
#undef VST0RE #undef VST0RE
} }
void mproduct::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C, array const & A, array const & B, array const & C,
value_scalar const & alpha, value_scalar const & beta, value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options) driver::Program & program, const char * suffix, execution_options_type const & options)
@@ -516,7 +518,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
} }
} }
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap) array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{ {
slice s0(s0_0, s0_1); slice s0(s0_0, s0_1);
slice s1(s1_0, s1_1); slice s1(s1_0, s1_1);
@@ -525,7 +527,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return array(M, s0, s1); return array(M, s0, s1);
} }
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const std::vector<int_t> gemm::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
{ {
isaac::array_expression & array_expression = (*expressions.data().front()); isaac::array_expression & array_expression = (*expressions.data().front());
array_expression::container_type & array = array_expression.tree(); array_expression::container_type & array = array_expression.tree();
@@ -537,26 +539,26 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return {M, N, K}; return {M, N, K};
} }
mproduct::mproduct(mproduct_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds) gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
{ {
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN_TYPE; if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN_TYPE; else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT_TYPE; else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT_TYPE; else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
else throw; else throw;
} }
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) const std::vector<int_t> gemm::input_sizes(expressions_tuple const & expressions) const
{ {
symbolic::preset::gemm::args dummy; symbolic::preset::gemm::args dummy;
return infos(expressions, dummy); return infos(expressions, dummy);
} }
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr) void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
{ {
using namespace tools; using namespace tools;
mproduct & fallback = (mproduct&)fallback_base; gemm & fallback = (gemm&)fallback_base;
expressions_tuple const & expressions = ctr.x(); expressions_tuple const & expressions = ctr.x();
@@ -579,8 +581,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t ldstrideA = pA->stride()[0]; int_t ldstrideA = pA->stride()[0];
int_t ldstrideB = pB->stride()[0]; int_t ldstrideB = pB->stride()[0];
int_t ldstrideC = pC->stride()[0]; int_t ldstrideC = pC->stride()[0];
int_t ldstartA = pA->start()[0];
int_t ldstartB = pB->start()[0];
numeric_type dtype = args.C->dtype; numeric_type dtype = args.C->dtype;
@@ -613,40 +613,41 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
} }
// //
mproduct_nn::mproduct_nn(unsigned int simd gemm_nn::gemm_nn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch , fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) : , int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N') gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
{ } { }
// //
mproduct_tn::mproduct_tn(unsigned int simd gemm_tn::gemm_tn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch , fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) : , int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N') gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
{ } { }
// //
mproduct_nt::mproduct_nt(unsigned int simd gemm_nt::gemm_nt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch , fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) : , int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T') gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
{ } { }
// //
mproduct_tt::mproduct_tt(unsigned int simd gemm_tt::gemm_tt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch , fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) : , int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T') gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
{ } { }
}
} }

View File

@@ -1,48 +1,50 @@
#include <iostream> #include <iostream>
#include "isaac/backend/stream.h" #include "isaac/backend/stream.h"
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
#include "isaac/backend/templates/mreduction.h" #include "isaac/backend/templates/gemv.h"
#include "isaac/tools/to_string.hpp" #include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp" #include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp" #include "isaac/tools/make_vector.hpp"
namespace isaac namespace isaac
{ {
namespace templates
{
mreduction_parameters::mreduction_parameters(unsigned int _simd_width, gemv_parameters::gemv_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { } num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
int mreduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{ {
if(reduction_type_==REDUCE_ROWS && p_.simd_width>1) if(dot_type_==REDUCE_ROWS && p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH; return TEMPLATE_INVALID_SIMD_WIDTH;
if (p_.fetch_policy==FETCH_FROM_LOCAL) if (p_.fetch_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE; return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
unsigned int mreduction::lmem_usage() const unsigned int gemv::lmem_usage() const
{ {
return (p_.local_size_0+1)*p_.local_size_1; return (p_.local_size_0+1)*p_.local_size_1;
} }
std::string mreduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const std::string gemv::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{ {
using tools::to_string; using tools::to_string;
std::vector<mapped_mreduction*> reductions; std::vector<mapped_gemv*> dots;
expressions_tuple::data_type::const_iterator sit; expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::const_iterator mit; std::vector<mapping_type>::const_iterator mit;
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit) for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
{ {
array_expression const & first_expression = *expressions.data().front(); array_expression const & first_expression = *expressions.data().front();
std::vector<size_t> idx = filter_nodes(&is_reduction, first_expression, false); std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
for (auto & elem : idx) for (auto & elem : idx)
reductions.push_back((mapped_mreduction*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get())); dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
} }
kernel_generation_stream stream; kernel_generation_stream stream;
@@ -54,10 +56,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
strcat(name[1], suffix); strcat(name[1], suffix);
std::string arguments = _size_t + " M, " + _size_t + " N, " ; std::string arguments = _size_t + " M, " + _size_t + " N, " ;
for (const auto & e : reductions) for (const auto & e : dots)
{ {
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype); std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
if (e->is_index_reduction()) if (e->is_index_dot())
{ {
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, "); arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,"); arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
@@ -87,7 +89,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
unsigned int local_size_0_ld = p_.local_size_0; unsigned int local_size_0_ld = p_.local_size_0;
std::string local_size_0_ld_str = to_string(local_size_0_ld); std::string local_size_0_ld_str = to_string(local_size_0_ld);
for (const auto & e : reductions) for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl; stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl; stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -104,7 +106,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl; stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (const auto & e : reductions) for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl; stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl; stream << "if (r < M)" << std::endl;
@@ -116,10 +118,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
std::string data_type = append_width("#scalartype",simd_width); std::string data_type = append_width("#scalartype",simd_width);
for (const auto & e : reductions) for (const auto & e : dots)
{ {
std::map<std::string, std::string> accessors; std::map<std::string, std::string> accessors;
if(reduction_type_==REDUCE_COLUMNS) if(dot_type_==REDUCE_COLUMNS)
{ {
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";"; accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";"; accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
@@ -141,20 +143,20 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
str[a] = access_vector_type("#namereg",a); str[a] = access_vector_type("#namereg",a);
for (auto & elem : reductions) for (auto & elem : dots)
for (unsigned int a = 0; a < simd_width; ++a) for (unsigned int a = 0; a < simd_width; ++a)
{ {
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}}); std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
if (elem->is_index_reduction()) if (elem->is_index_dot())
compute_index_reduction(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op()); compute_index_dot(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
else else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op()); compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
} }
}); });
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
for (auto & expr : reductions) for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl; stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
@@ -167,13 +169,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (auto & e : reductions) for (auto & e : dots)
if (e->is_index_reduction()) if (e->is_index_dot())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]") compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]") , e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op()); , e->root_op());
else else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op()); compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
@@ -188,15 +190,15 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
if(p_.num_groups_0==1) if(p_.num_groups_0==1)
{ {
std::map<std::string, std::string> accessors; std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]"; accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]"; accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings); evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
} }
else else
{ {
for (mapped_reduction const * e : reductions) for (mapped_dot const * e : dots)
{ {
if (e->is_index_reduction()) if (e->is_index_dot())
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl; stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl; stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
} }
@@ -230,7 +232,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
{"array2", "#pointer += #start1 + #start2*#ld; " {"array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "}}, expressions, mappings); "#ld *= #nldstride; "}}, expressions, mappings);
for (const auto & e : reductions) for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl; stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl; stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -246,7 +248,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl; stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (const auto & e : reductions) for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl; stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl; stream << "if (r < M)" << std::endl;
@@ -256,8 +258,8 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl; stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (mapped_reduction* e: reductions) for (mapped_dot* e: dots)
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op()); compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
@@ -266,7 +268,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
for (auto & expr : reductions) for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl; stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
@@ -279,13 +281,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (auto & e : reductions) for (auto & e : dots)
if (e->is_index_reduction()) if (e->is_index_dot())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]") compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]") , e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op()); , e->root_op());
else else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op()); compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
@@ -299,7 +301,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.inc_tab(); stream.inc_tab();
std::map<std::string, std::string> accessors; std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]"; accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]"; accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings); evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
@@ -317,38 +319,38 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
return stream.str(); return stream.str();
} }
mreduction::mreduction(mreduction::parameters_type const & parameters, gemv::gemv(gemv::parameters_type const & parameters,
mreduction::reduction_type rtype, gemv::dot_type rtype,
binding_policy_t binding_policy) : binding_policy_t binding_policy) :
base_impl<mreduction, mreduction_parameters>(parameters, binding_policy), base_impl<gemv, gemv_parameters>(parameters, binding_policy),
reduction_type_(rtype){ } dot_type_(rtype){ }
std::vector<int_t> mreduction::input_sizes(expressions_tuple const & expressions) const std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
{ {
array_expression const & first_expression = *expressions.data().front(); array_expression const & first_expression = *expressions.data().front();
std::vector<std::size_t> idx = filter_nodes(&is_reduction, first_expression, false); std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0])); std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
if(reduction_type_==REDUCE_COLUMNS) if(dot_type_==REDUCE_COLUMNS)
std::swap(MN.first,MN.second); std::swap(MN.first,MN.second);
return tools::make_vector<int_t>() << MN.first << MN.second; return tools::make_vector<int_t>() << MN.first << MN.second;
} }
void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller) void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{ {
expressions_tuple const & expressions = controller.x(); expressions_tuple const & expressions = controller.x();
driver::Context const & context = expressions.context(); driver::Context const & context = expressions.context();
std::vector<int_t> MN = input_sizes(expressions); std::vector<int_t> MN = input_sizes(expressions);
std::vector<array_expression::node const *> reductions; std::vector<array_expression::node const *> dots;
for (const auto & e : expressions.data()) for (const auto & e : expressions.data())
{ {
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *e, false); std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
for (auto & r : reductions_idx) for (auto & r : dots_idx)
reductions.push_back(&(e)->tree()[r]); dots.push_back(&(e)->tree()[r]);
} }
//Fallback //Fallback
if(reduction_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions)) if(dot_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
{ {
fallback.enqueue(queue, program, "fallback", fallback, controller); fallback.enqueue(queue, program, "fallback", fallback, controller);
return; return;
@@ -381,9 +383,9 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
//Temporary buffers //Temporary buffers
unsigned int i = 0; unsigned int i = 0;
unsigned int j = 0; unsigned int j = 0;
for (auto const & r : reductions) for (auto const & r : dots)
{ {
if (is_index_reduction(r->op)) if (is_index_dot(r->op))
{ {
if (tmpidx.size() <= j) if (tmpidx.size() <= j)
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4)); tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
@@ -405,24 +407,25 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]); controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
} }
mreduction_rows::mreduction_rows(mreduction_parameters const & parameters, gemv_n::gemv_n(gemv_parameters const & parameters,
binding_policy_t binding_policy): binding_policy_t binding_policy):
mreduction(parameters, REDUCE_ROWS, binding_policy){} gemv(parameters, REDUCE_ROWS, binding_policy){}
mreduction_rows::mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind): unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind) gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
{} {}
mreduction_cols::mreduction_cols(mreduction::parameters_type const & parameters, gemv_t::gemv_t(gemv::parameters_type const & parameters,
binding_policy_t binding_policy): binding_policy_t binding_policy):
mreduction(parameters, REDUCE_COLUMNS, binding_policy){} gemv(parameters, REDUCE_COLUMNS, binding_policy){}
mreduction_cols::mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind): unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind) gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
{} {}
} }
}

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/maxpy.h" #include "isaac/backend/templates/ger.h"
#include "isaac/tools/make_map.hpp" #include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp" #include "isaac/tools/make_vector.hpp"
#include "isaac/symbolic/io.h" #include "isaac/symbolic/io.h"
@@ -7,15 +7,17 @@
namespace isaac namespace isaac
{ {
namespace templates
{
maxpy_parameters::maxpy_parameters(unsigned int _simd_width, ger_parameters::ger_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, unsigned int _num_groups_0, unsigned int _num_groups_1,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ } fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const int ger::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{ {
if (p_.simd_width>1) if (p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH; return TEMPLATE_INVALID_SIMD_WIDTH;
@@ -24,7 +26,7 @@ int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) co
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
std::string maxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const std::string ger::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{ {
kernel_generation_stream stream; kernel_generation_stream stream;
std::string _size_t = size_type(device); std::string _size_t = size_type(device);
@@ -95,23 +97,23 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str(); return stream.str();
} }
maxpy::maxpy(parameters_type const & parameters, binding_policy_t binding_policy) : ger::ger(parameters_type const & parameters, binding_policy_t binding_policy) :
base_impl<maxpy, maxpy_parameters>(parameters, binding_policy){ } base_impl<ger, ger_parameters>(parameters, binding_policy){ }
maxpy::maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2, ger::ger(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
binding_policy_t bind): binding_policy_t bind):
base_impl<maxpy, maxpy_parameters>(maxpy_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind) base_impl<ger, ger_parameters>(ger_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
{} {}
std::vector<int_t> maxpy::input_sizes(expressions_tuple const & expressions) const std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
{ {
isaac::array_expression const & array_expression = *(expressions.data().front()); isaac::array_expression const & array_expression = *(expressions.data().front());
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root())); std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
return tools::make_vector<int_t>() << size.first << size.second; return tools::make_vector<int_t>() << size.first << size.second;
} }
void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller) void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
{ {
expressions_tuple const & expressions = controller.x(); expressions_tuple const & expressions = controller.x();
char name[32] = {"axpy"}; char name[32] = {"axpy"};
@@ -129,3 +131,4 @@ void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
} }
} }
}

View File

@@ -7,11 +7,11 @@
#include "rapidjson/document.h" #include "rapidjson/document.h"
#include "isaac/backend/parse.h" #include "isaac/backend/parse.h"
#include "isaac/backend/templates/vaxpy.h" #include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/reduction.h" #include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/maxpy.h" #include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/mreduction.h" #include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/mproduct.h" #include "isaac/backend/templates/gemm.h"
#include "isaac/exception/unknown_datatype.h" #include "isaac/exception/unknown_datatype.h"
#include "isaac/exception/operation_not_supported.h" #include "isaac/exception/operation_not_supported.h"
#include "isaac/model/model.h" #include "isaac/model/model.h"
@@ -86,12 +86,12 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
return *program; return *program;
} }
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<base> > const & templates, driver::CommandQueue const & queue) : model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue) templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
{} {}
model::model(expression_type etype, numeric_type dtype, base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue) model::model(expression_type etype, numeric_type dtype, templates::base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue)
{} {}
void model::execute(controller<expressions_tuple> const & expr) void model::execute(controller<expressions_tuple> const & expr)
@@ -148,15 +148,15 @@ namespace detail
{ {
static expression_type get_expression_type(std::string const & name) static expression_type get_expression_type(std::string const & name)
{ {
if(name=="vaxpy") return VECTOR_AXPY_TYPE; if(name=="axpy") return AXPY_TYPE;
if(name=="dot") return REDUCTION_TYPE; if(name=="dot") return DOT_TYPE;
if(name=="maxpy") return MATRIX_AXPY_TYPE; if(name=="ger") return GER_TYPE;
if(name=="mreduction_rows") return ROW_WISE_REDUCTION_TYPE; if(name=="gemv_n") return GEMV_N_TYPE;
if(name=="mreduction_cols") return COL_WISE_REDUCTION_TYPE; if(name=="gemv_t") return GEMV_T_TYPE;
if(name=="mproduct_nn") return MATRIX_PRODUCT_NN_TYPE; if(name=="gemm_nn") return GEMM_NN_TYPE;
if(name=="mproduct_nt") return MATRIX_PRODUCT_NT_TYPE; if(name=="gemm_nt") return GEMM_NT_TYPE;
if(name=="mproduct_tn") return MATRIX_PRODUCT_TN_TYPE; if(name=="gemm_tn") return GEMM_TN_TYPE;
if(name=="mproduct_tt") return MATRIX_PRODUCT_TT_TYPE; if(name=="gemm_tt") return GEMM_TT_TYPE;
throw std::invalid_argument("Invalid expression: " + name); throw std::invalid_argument("Invalid expression: " + name);
} }
@@ -167,27 +167,27 @@ namespace detail
throw std::invalid_argument("Invalid datatype: " + name); throw std::invalid_argument("Invalid datatype: " + name);
} }
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a) static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
{ {
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS}; templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="vaxpy") if(template_name=="axpy")
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]])); return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="dot") else if(template_name=="dot")
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]])); return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="maxpy") else if(template_name=="ger")
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]])); return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mreduction_rows")!=std::string::npos) else if(template_name.find("gemv_n")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]])); return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mreduction_cols")!=std::string::npos) else if(template_name.find("gemv_t")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]])); return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mproduct_nn")!=std::string::npos) else if(template_name.find("gemm_nn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11])); return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_tn")!=std::string::npos) else if(template_name.find("gemm_tn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11])); return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_nt")!=std::string::npos) else if(template_name.find("gemm_nt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11])); return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_tt")!=std::string::npos) else if(template_name.find("gemm_tt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11])); return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else else
throw std::invalid_argument("Invalid expression: " + template_name); throw std::invalid_argument("Invalid expression: " + template_name);
} }
@@ -207,7 +207,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>()); str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
document.Parse<0>(str.c_str()); document.Parse<0>(str.c_str());
//Deserialize //Deserialize
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemv_n", "gemv_t", "mproduct_nn", "mproduct_tn", "mproduct_nt", "mproduct_tt"}; std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
std::vector<std::string> dtype = {"float32", "float64"}; std::vector<std::string> dtype = {"float32", "float64"};
for(auto & operation : operations) for(auto & operation : operations)
{ {
@@ -223,7 +223,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
numeric_type dtype = detail::get_dtype(elem); numeric_type dtype = detail::get_dtype(elem);
// Get profiles // Get profiles
std::vector<tools::shared_ptr<base> > templates; std::vector<tools::shared_ptr<templates::base> > templates;
js::Value const & profiles = document[opcstr][dtcstr]["profiles"]; js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id) for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id]))); templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
@@ -243,23 +243,22 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
} }
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > init_fallback() std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
{ {
typedef tools::shared_ptr<base> ptr_t; typedef tools::shared_ptr<templates::base> ptr_t;
std::map<std::pair<expression_type, numeric_type>, ptr_t > res; std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE}; numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types) for(auto DTYPE : types)
{ {
res[std::make_pair(SCALAR_AXPY_TYPE, DTYPE)] = ptr_t(new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(AXPY_TYPE, DTYPE)] = ptr_t (new templates::axpy(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(DOT_TYPE, DTYPE)] = ptr_t(new templates::dot(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::ger(1,8,8,8,8,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::gemv_n(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::gemv_t(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::gemm_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
} }
return res; return res;
} }
@@ -269,7 +268,7 @@ model_map_t init_models(driver::CommandQueue & queue)
{ {
model_map_t res; model_map_t res;
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE}; numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
expression_type etypes[] = {SCALAR_AXPY_TYPE, VECTOR_AXPY_TYPE, REDUCTION_TYPE, MATRIX_AXPY_TYPE, ROW_WISE_REDUCTION_TYPE, COL_WISE_REDUCTION_TYPE, MATRIX_PRODUCT_NN_TYPE, MATRIX_PRODUCT_NT_TYPE, MATRIX_PRODUCT_TN_TYPE, MATRIX_PRODUCT_TT_TYPE}; expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
for(numeric_type dtype: dtypes) for(numeric_type dtype: dtypes)
for(expression_type etype: etypes) for(expression_type etype: etypes)
@@ -288,7 +287,7 @@ model_map_t& models(driver::CommandQueue & queue)
return it->second; return it->second;
} }
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks = init_fallback(); std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<driver::CommandQueue, model_map_t> models_; std::map<driver::CommandQueue, model_map_t> models_;
} }

View File

@@ -19,13 +19,13 @@ namespace isaac
inline bool is_mmprod(expression_type x) inline bool is_mmprod(expression_type x)
{ {
return x==MATRIX_PRODUCT_NN_TYPE || x==MATRIX_PRODUCT_NT_TYPE || return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
x==MATRIX_PRODUCT_TN_TYPE || x==MATRIX_PRODUCT_TT_TYPE; x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
} }
inline bool is_mvprod(expression_type x) inline bool is_mvprod(expression_type x)
{ {
return x==ROW_WISE_REDUCTION_TYPE || x==COL_WISE_REDUCTION_TYPE; return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
} }
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first) inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
@@ -36,27 +36,27 @@ namespace isaac
case OPERATOR_UNARY_TYPE_FAMILY: case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY: case OPERATOR_BINARY_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| (result |= expression==ROW_WISE_REDUCTION_TYPE && other==COL_WISE_REDUCTION_TYPE) || (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|| (result |= expression==COL_WISE_REDUCTION_TYPE && other==ROW_WISE_REDUCTION_TYPE); || (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
break; break;
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY: case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression) result |= is_mvprod(expression)
|| expression==REDUCTION_TYPE; || expression==DOT_TYPE;
break; break;
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY: case OPERATOR_ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCTION_TYPE; || expression==DOT_TYPE;
break; break;
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY: case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCTION_TYPE; || expression==DOT_TYPE;
break; break;
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY: case OPERATOR_GEMM_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first) result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==REDUCTION_TYPE; || expression==DOT_TYPE;
break; break;
default: default:
break; break;
@@ -77,28 +77,28 @@ namespace isaac
{ {
case OPERATOR_UNARY_TYPE_FAMILY: case OPERATOR_UNARY_TYPE_FAMILY:
if(is_mmprod(left)) if(is_mmprod(left))
return MATRIX_AXPY_TYPE; return GER_TYPE;
return left; return left;
case OPERATOR_BINARY_TYPE_FAMILY: case OPERATOR_BINARY_TYPE_FAMILY:
if(left == ROW_WISE_REDUCTION_TYPE || right == ROW_WISE_REDUCTION_TYPE) return ROW_WISE_REDUCTION_TYPE; if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
else if(left == COL_WISE_REDUCTION_TYPE || right == COL_WISE_REDUCTION_TYPE) return COL_WISE_REDUCTION_TYPE; else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
else if(left == REDUCTION_TYPE || right == REDUCTION_TYPE) return REDUCTION_TYPE; else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
else if(left == VECTOR_AXPY_TYPE || right == VECTOR_AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?MATRIX_AXPY_TYPE:VECTOR_AXPY_TYPE; else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
else if(left == MATRIX_AXPY_TYPE || right == MATRIX_AXPY_TYPE) return MATRIX_AXPY_TYPE; else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
else if(is_mmprod(left) || is_mmprod(right)) return MATRIX_AXPY_TYPE; else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
std::cout << left << " " << right << std::endl; std::cout << left << " " << right << std::endl;
throw; throw;
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY: case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
return REDUCTION_TYPE; return DOT_TYPE;
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY: case OPERATOR_ROWS_DOT_TYPE_FAMILY:
return ROW_WISE_REDUCTION_TYPE; return GEMV_N_TYPE;
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY: case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
return COL_WISE_REDUCTION_TYPE; return GEMV_T_TYPE;
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY: case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN_TYPE; if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
else if(op.type==OPERATOR_MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN_TYPE; else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
else if(op.type==OPERATOR_MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT_TYPE; else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
else return MATRIX_PRODUCT_TT_TYPE; else return GEMM_TT_TYPE;
default: default:
throw; throw;
} }
@@ -119,9 +119,9 @@ namespace isaac
else if(node.lhs.subtype == DENSE_ARRAY_TYPE) else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.lhs.array->nshape()==1) if(node.lhs.array->nshape()==1)
type_left = VECTOR_AXPY_TYPE; type_left = AXPY_TYPE;
else else
type_left = MATRIX_AXPY_TYPE; type_left = GER_TYPE;
} }
//Right //Right
@@ -131,9 +131,9 @@ namespace isaac
else if(node.rhs.subtype == DENSE_ARRAY_TYPE) else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.rhs.array->nshape()==1) if(node.rhs.array->nshape()==1)
type_right = VECTOR_AXPY_TYPE; type_right = AXPY_TYPE;
else else
type_right = MATRIX_AXPY_TYPE; type_right = GER_TYPE;
} }
@@ -171,12 +171,10 @@ namespace isaac
//Init //Init
expression_type current_type; expression_type current_type;
if(root_save.lhs.array->nshape()==0) if(root_save.lhs.array->nshape()<=1)
current_type = SCALAR_AXPY_TYPE; current_type=AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
else else
current_type=MATRIX_AXPY_TYPE; current_type=GER_TYPE;
final_type = current_type; final_type = current_type;
/*----Parse required temporaries-----*/ /*----Parse required temporaries-----*/
@@ -193,18 +191,17 @@ namespace isaac
//Creates temporary //Creates temporary
tools::shared_ptr<array> tmp; tools::shared_ptr<array> tmp;
switch(it->first){ switch(it->first){
case SCALAR_AXPY_TYPE: case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break; case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break; case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break; case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break; case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break; case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break; case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break; case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break; case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation"); default: throw std::invalid_argument("Unrecognized operation");
} }

View File

@@ -12,16 +12,16 @@ namespace preset
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a) void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
{ {
//Matrix-Matrix product node //Matrix-Matrix product node
if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY) if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
{ {
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs; if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs; if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
switch(tree[rootidx].op.type) switch(tree[rootidx].op.type)
{ {
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break; case OPERATOR_GEMM_NN_TYPE: a.type = GEMM_NN_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break; case OPERATOR_GEMM_NT_TYPE: a.type = GEMM_NT_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break; case OPERATOR_GEMM_TN_TYPE: a.type = GEMM_TN_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break; case OPERATOR_GEMM_TT_TYPE: a.type = GEMM_TT_TYPE; break;
default: break; default: break;
} }
} }
@@ -31,7 +31,7 @@ void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, a
{ {
//alpha*PROD //alpha*PROD
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY) && tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
{ {
a.alpha = &tree[rootidx].lhs; a.alpha = &tree[rootidx].lhs;
handle_node(tree, tree[rootidx].rhs.node_index, a); handle_node(tree, tree[rootidx].rhs.node_index, a);

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")] include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files #Source files
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']] src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/' boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']: for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x] src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -20,6 +20,6 @@ BOOST_PYTHON_MODULE(_isaac)
export_driver(); export_driver();
export_exceptions(); export_exceptions();
export_model(); export_templates();
export_core(); export_core();
} }

View File

@@ -70,15 +70,15 @@ namespace tools
else else
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))(); name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
if(name=="vaxpy") return isc::VECTOR_AXPY_TYPE; if(name=="axpy") return isc::AXPY_TYPE;
else if(name=="maxpy") return isc::MATRIX_AXPY_TYPE; else if(name=="ger") return isc::GER_TYPE;
else if(name=="reduction") return isc::REDUCTION_TYPE; else if(name=="dot") return isc::DOT_TYPE;
else if(name=="mreduction_rows") return isc::ROW_WISE_REDUCTION_TYPE; else if(name=="gemv_n") return isc::GEMV_N_TYPE;
else if(name=="mreduction_cols") return isc::COL_WISE_REDUCTION_TYPE; else if(name=="gemm_t") return isc::GEMV_T_TYPE;
else if(name=="mproduct_nn") return isc::MATRIX_PRODUCT_NN_TYPE; else if(name=="gemm_nn") return isc::GEMM_NN_TYPE;
else if(name=="mproduct_tn") return isc::MATRIX_PRODUCT_TN_TYPE; else if(name=="gemm_tn") return isc::GEMM_TN_TYPE;
else if(name=="mproduct_nt") return isc::MATRIX_PRODUCT_NT_TYPE; else if(name=="gemm_nt") return isc::GEMM_NT_TYPE;
else if(name=="mproduct_tt") return isc::MATRIX_PRODUCT_TT_TYPE; else if(name=="gemm_tt") return isc::GEMM_TT_TYPE;
else else
{ {
PyErr_SetString(PyExc_TypeError, "Template type not understood"); PyErr_SetString(PyExc_TypeError, "Template type not understood");

View File

@@ -1,3 +1,4 @@
#include "isaac/model/model.h"
#include "common.hpp" #include "common.hpp"
#include "core.h" #include "core.h"
@@ -79,6 +80,11 @@ unsigned int size(datatype<T> const & dt)
namespace detail namespace detail
{ {
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
{
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isaac::templates::base const &)bp::extract<isaac::templates::base>(tp), queue));
}
std::shared_ptr<isc::array> std::shared_ptr<isc::array>
ndarray_to_iscarray(const np::ndarray& array, const isc::driver::Context& ctx) ndarray_to_iscarray(const np::ndarray& array, const isc::driver::Context& ctx)
{ {
@@ -172,7 +178,6 @@ namespace detail
bp::throw_error_already_set(); bp::throw_error_already_set();
throw; throw;
} }
} }
} }
@@ -183,6 +188,10 @@ namespace detail
void export_core() void export_core()
{ {
bp::class_<isaac::model>("model", bp::no_init)
.def("__init__", bp::make_constructor(detail::construct_model))
.def("execute", &isc::model::execute);
bp::class_<isc::value_scalar>("value_scalar", bp::no_init) bp::class_<isc::value_scalar>("value_scalar", bp::no_init)
.add_property("dtype", &isc::value_scalar::dtype); .add_property("dtype", &isc::value_scalar::dtype);
@@ -206,15 +215,15 @@ void export_core()
#undef INSTANTIATE #undef INSTANTIATE
bp::enum_<isc::expression_type>("operations") bp::enum_<isc::expression_type>("operations")
MAP_ENUM(VECTOR_AXPY_TYPE, isc) MAP_ENUM(AXPY_TYPE, isc)
MAP_ENUM(MATRIX_AXPY_TYPE, isc) MAP_ENUM(GER_TYPE, isc)
MAP_ENUM(REDUCTION_TYPE, isc) MAP_ENUM(DOT_TYPE, isc)
MAP_ENUM(ROW_WISE_REDUCTION_TYPE, isc) MAP_ENUM(GEMV_N_TYPE, isc)
MAP_ENUM(COL_WISE_REDUCTION_TYPE, isc) MAP_ENUM(GEMV_T_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc) MAP_ENUM(GEMM_NN_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc) MAP_ENUM(GEMM_TN_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc) MAP_ENUM(GEMM_NT_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc); MAP_ENUM(GEMM_TT_TYPE, isc);
#define ADD_SCALAR_HANDLING(OP)\ #define ADD_SCALAR_HANDLING(OP)\
.def(bp::self OP int())\ .def(bp::self OP int())\

View File

@@ -1,70 +1,71 @@
#include "isaac/backend/templates/vaxpy.h" #include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/maxpy.h" #include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/reduction.h" #include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/mreduction.h" #include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/mproduct.h" #include "isaac/backend/templates/gemm.h"
#include "isaac/model/model.h" #include "isaac/model/model.h"
#include "common.hpp" #include "common.hpp"
#include "model.h" #include "model.h"
namespace tpt = isaac::templates;
namespace detail namespace detail
{ {
bp::list input_sizes(isaac::base & temp, isc::expressions_tuple const & tree) bp::list input_sizes(tpt::base & temp, isc::expressions_tuple const & tree)
{ {
std::vector<int> tmp = temp.input_sizes(tree); std::vector<int> tmp = temp.input_sizes(tree);
return tools::to_list(tmp.begin(), tmp.end()); return tools::to_list(tmp.begin(), tmp.end());
} }
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
{
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isc::base const &)bp::extract<isc::base>(tp), queue));
}
} }
void export_model() void export_templates()
{ {
bp::class_<isaac::model>("model", bp::no_init) bp::object templates_module(bp::handle<>(bp::borrowed(PyImport_AddModule("isaac.templates"))));
.def("__init__", bp::make_constructor(detail::construct_model)) bp::scope().attr("templates") = templates_module;
.def("execute", &isc::model::execute); bp::scope template_scope = templates_module;
bp::enum_<isaac::fetching_policy_type>
bp::enum_<tpt::fetching_policy_type>
("fetching_policy_type") ("fetching_policy_type")
.value("FETCH_FROM_LOCAL", isc::FETCH_FROM_LOCAL) .value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL)
.value("FETCH_FROM_GLOBAL_STRIDED", isc::FETCH_FROM_GLOBAL_STRIDED) .value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED)
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", isc::FETCH_FROM_GLOBAL_CONTIGUOUS) .value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS)
; ;
//Base //Base
{ {
#define __PROP(name) .def_readonly(#name, &isaac::base::parameters_type::name) #define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
bp::class_<isaac::base, boost::noncopyable>("base", bp::no_init) bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
.def("lmem_usage", &isaac::base::lmem_usage) .def("lmem_usage", &tpt::base::lmem_usage)
.def("registers_usage", &isaac::base::registers_usage) .def("registers_usage", &tpt::base::registers_usage)
.def("is_invalid", &isaac::base::is_invalid) .def("is_invalid", &tpt::base::is_invalid)
.def("input_sizes", &detail::input_sizes) .def("input_sizes", &detail::input_sizes)
; ;
#undef __PROP #undef __PROP
} }
#define WRAP_BASE(name) bp::class_<isaac::base_impl<isaac::name, isaac::name::parameters_type>, bp::bases<isaac::base>, boost::noncopyable>(#name, bp::no_init); #define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<isaac::name, bp::bases<isaac::base_impl<isaac::basename, isaac::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\ #define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
.add_property("local_size_0", &isc::name::local_size_0)\ .add_property("local_size_0", &tpt::name::local_size_0)\
.add_property("local_size_1", &isc::name::local_size_1); .add_property("local_size_1", &tpt::name::local_size_1);
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__) #define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
//Vector AXPY //Vector AXPY
WRAP_SINGLE_TEMPLATE(vaxpy, uint, uint, uint, isaac::fetching_policy_type) WRAP_SINGLE_TEMPLATE(axpy, uint, uint, uint, tpt::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(maxpy, uint, uint, uint, uint, uint, isaac::fetching_policy_type) WRAP_SINGLE_TEMPLATE(ger, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(reduction, uint, uint, uint, isaac::fetching_policy_type) WRAP_SINGLE_TEMPLATE(dot, uint, uint, uint, tpt::fetching_policy_type)
WRAP_BASE(mreduction) WRAP_BASE(gemv)
WRAP_TEMPLATE(mreduction_rows, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type) WRAP_TEMPLATE(gemv_n, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_TEMPLATE(mreduction_cols, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type) WRAP_TEMPLATE(gemv_t, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_BASE(mproduct) WRAP_BASE(gemm)
WRAP_TEMPLATE(mproduct_nn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint) WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_tn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint) WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_nt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint) WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_tt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint) WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
} }

View File

@@ -1,6 +1,6 @@
#ifndef ISAAC_PYTHON_MODEL_HPP #ifndef ISAAC_PYTHON_MODEL_HPP
#define ISAAC_PYTHON_MODEL_HPP #define ISAAC_PYTHON_MODEL_HPP
void export_model(); void export_templates();
#endif #endif

View File

@@ -4,7 +4,7 @@ if(CUDA_FOUND)
endif() endif()
get_property(ISAAC_PATH TARGET isaac PROPERTY LOCATION) get_property(ISAAC_PATH TARGET isaac PROPERTY LOCATION)
foreach(PROG maxpy vaxpy reduction mreduction mproduct) foreach(PROG axpy dot ger gemv gemm)
add_executable(${PROG}-test ${PROG}.cpp) add_executable(${PROG}-test ${PROG}.cpp)
add_test(${PROG} ${PROG}-test) add_test(${PROG} ${PROG}-test)
target_link_libraries(${PROG}-test isaac ${OPENCL_LIBRARIES}) target_link_libraries(${PROG}-test isaac ${OPENCL_LIBRARIES})

View File

@@ -14,10 +14,10 @@ from numpy import cumsum
import tools import tools
fetch_types = [isc.fetching_policy_type.FETCH_FROM_LOCAL, fetch_types = [isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
isc.fetching_policy_type.FETCH_FROM_LOCAL, isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
isc.fetching_policy_type.FETCH_FROM_LOCAL, isc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL] isc.templates.fetching_policy_type.FETCH_FROM_LOCAL]
def exhaustive(template, sizes, context): def exhaustive(template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context) tree, _ = tools.tree_of(template, sizes, context)
@@ -159,15 +159,15 @@ def is_local_optimum(parameters, template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context) tree, _ = tools.tree_of(template, sizes, context)
genetic_infos = tools.genetic_infos_of(template) genetic_infos = tools.genetic_infos_of(template)
if issubclass(template, isc.vaxpy): if issubclass(template, isc.axpy):
sweep_over = [0,1,2] sweep_over = [0,1,2]
elif issubclass(template, isc.reduction): elif issubclass(template, isc.dot):
sweep_over = [0,1,2] sweep_over = [0,1,2]
elif issubclass(template, isc.maxpy): elif issubclass(template, isc.ger):
sweep_over = [0,1,2,3,4] sweep_over = [0,1,2,3,4]
elif issubclass(template, isc.mreduction): elif issubclass(template, isc.gemv):
sweep_over = [0,1,2,3,4] sweep_over = [0,1,2,3,4]
elif issubclass(template, isc.mproduct): elif issubclass(template, isc.gemm):
sweep_over = [1,2,3,4,5,7,10,11] sweep_over = [1,2,3,4,5,7,10,11]
#Evaluate the provided parameters guess #Evaluate the provided parameters guess

View File

@@ -36,30 +36,30 @@ def benchmark(template, setting, tree):
def tree_of(template, sizes, context): def tree_of(template, sizes, context):
if issubclass(template, isc.vaxpy): if issubclass(template, isc.templates.axpy):
N, = sizes N, = sizes
x = isc.empty(N, dtype=isc.float32, context=context) x = isc.empty(N, dtype=isc.float32, context=context)
y = isc.empty(N, dtype=isc.float32, context=context) y = isc.empty(N, dtype=isc.float32, context=context)
return x + y, (x, y) return x + y, (x, y)
elif issubclass(template, isc.reduction): elif issubclass(template, isc.templates.dot):
N, = sizes N, = sizes
x = isc.empty(N, context=context) x = isc.empty(N, context=context)
y = isc.empty(N, context=context) y = isc.empty(N, context=context)
return isc.dot(x, y), (x, y) return isc.dot(x, y), (x, y)
elif issubclass(template, isc.maxpy): elif issubclass(template, isc.templates.ger):
M, N = sizes M, N = sizes
A = isc.empty((M,N), context=context) A = isc.empty((M,N), context=context)
B = isc.empty((M,N), context=context) B = isc.empty((M,N), context=context)
return A + B, (A, B) return A + B, (A, B)
elif issubclass(template, isc.mreduction): elif issubclass(template, isc.templates.gemv):
T = template is isc.mreduction_cols T = template is isc.templates.gemv_t
M, N = sizes[::-1] if T else sizes M, N = sizes[::-1] if T else sizes
A = isc.empty((M,N), context=context) A = isc.empty((M,N), context=context)
x = isc.empty(N, context=context) x = isc.empty(N, context=context)
return isc.dot(A.T, x) if T else isc.dot(A, x), (A, x) return isc.dot(A.T, x) if T else isc.dot(A, x), (A, x)
elif issubclass(template, isc.mproduct): elif issubclass(template, isc.templates.gemm):
AT = template is isc.mproduct_tn or template is isc.mproduct_tt AT = template is isc.templates.gemm_tn or template is isc.templates.gemm_tt
BT = template is isc.mproduct_nt or template is isc.mproduct_tt BT = template is isc.templates.gemm_nt or template is isc.templates.gemm_tt
M, N, K = sizes M, N, K = sizes
A = isc.empty((K, M) if AT else (M, K), context=context) A = isc.empty((K, M) if AT else (M, K), context=context)
B = isc.empty((N, K) if BT else (K, N), context=context) B = isc.empty((N, K) if BT else (K, N), context=context)
@@ -68,35 +68,35 @@ def tree_of(template, sizes, context):
return isc.dot(AA, BB), (A, B) return isc.dot(AA, BB), (A, B)
def memory_footprint(template, sizes): def memory_footprint(template, sizes):
if issubclass(template, isc.vaxpy): if issubclass(template, isc.templates.axpy):
return 4*3*sizes[0]*1e-9 return 4*3*sizes[0]*1e-9
elif issubclass(template, isc.reduction): elif issubclass(template, isc.templates.dot):
return 4*2*sizes[0]*1e-9 return 4*2*sizes[0]*1e-9
elif issubclass(template, isc.maxpy): elif issubclass(template, isc.templates.ger):
return 4*3*sizes[0]*sizes[1]*1e-9 return 4*3*sizes[0]*sizes[1]*1e-9
elif issubclass(template, isc.mreduction): elif issubclass(template, isc.templates.gemv):
return 4*sizes[0]*sizes[1]*1e-9 return 4*sizes[0]*sizes[1]*1e-9
elif issubclass(template, isc.mproduct): elif issubclass(template, isc.templates.gemm):
return 4*(sizes[0]*sizes[1] + sizes[0]*sizes[2] + sizes[1]*sizes[2])*1e-9 return 4*(sizes[0]*sizes[1] + sizes[0]*sizes[2] + sizes[1]*sizes[2])*1e-9
def metric_of(template): def metric_of(template):
memory_bound = [isc.vaxpy, isc.reduction, isc.maxpy, isc.mreduction] memory_bound = [isc.templates.axpy, isc.templates.dot, isc.templates.ger, isc.templates.gemv]
compute_bound = [isc.mproduct] compute_bound = [isc.templates.gemm]
if any([issubclass(template, x) for x in memory_bound]): if any([issubclass(template, x) for x in memory_bound]):
return lambda sizes, t: memory_footprint(template, sizes)/t return lambda sizes, t: memory_footprint(template, sizes)/t
elif any([issubclass(template, x) for x in compute_bound]): elif any([issubclass(template, x) for x in compute_bound]):
return lambda sizes, t: 2*sizes[0]*sizes[1]*sizes[2]*1e-9/t return lambda sizes, t: 2*sizes[0]*sizes[1]*sizes[2]*1e-9/t
def genetic_infos_of(template): def genetic_infos_of(template):
if issubclass(template, isc.vaxpy): if issubclass(template, isc.templates.axpy):
return {'categorical': [3], 'nbits': [3,4,4,2] } return {'categorical': [3], 'nbits': [3,4,4,2] }
elif issubclass(template, isc.reduction): elif issubclass(template, isc.templates.dot):
return {'categorical': [3], 'nbits':[3,4,4,2]} return {'categorical': [3], 'nbits':[3,4,4,2]}
elif issubclass(template, isc.maxpy): elif issubclass(template, isc.templates.ger):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]} return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
elif issubclass(template, isc.mreduction): elif issubclass(template, isc.templates.gemv):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]} return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
elif issubclass(template, isc.mproduct): elif issubclass(template, isc.templates.gemm):
return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]} return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]}

View File

@@ -23,14 +23,14 @@ def tune(device, operation, json_path):
#List of size tuples to use #List of size tuples to use
sizes = {} sizes = {}
sizes[isc.vaxpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)] sizes[isc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
sizes[isc.mreduction_rows] = product(pow2range(4,17), pow2range(4,17)) sizes[isc.templates.gemv_n] = product(pow2range(4,17), pow2range(4,17))
sizes[isc.mreduction_cols] = isc.mreduction_rows sizes[isc.templates.gemv_t] = sizes[isc.templates.gemv_n]
sizes[isc.mproduct_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10)) sizes[isc.templates.gemm_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
sizes[isc.mproduct_nn] = [(128, 169, 1728)] sizes[isc.templates.gemm_nn] = [(128, 169, 1728)]
sizes[isc.mproduct_tn] = sizes[isc.mproduct_nn] sizes[isc.templates.gemm_tn] = sizes[isc.templates.gemm_nn]
sizes[isc.mproduct_nt] = sizes[isc.mproduct_nn] sizes[isc.templates.gemm_nt] = sizes[isc.templates.gemm_nn]
sizes[isc.mproduct_tt] = sizes[isc.mproduct_nn] sizes[isc.templates.gemm_tt] = sizes[isc.templates.gemm_nn]
sizes = unique(list(sizes[operation])) sizes = unique(list(sizes[operation]))
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1] sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
@@ -123,9 +123,9 @@ def parse_arguments():
print("----------------") print("----------------")
operation = {'vaxpy': isc.vaxpy, 'dot': isc.reduction, operation = {'axpy': isc.templates.axpy, 'dot': isc.templates.dot,
'maxpy': isc.maxpy, 'gemv_n': isc.mreduction_rows, 'gemv_t': isc.mreduction_cols, 'ger': isc.templates.ger, 'gemv_n': isc.templates.gemv_n, 'gemv_t': isc.templates.gemv_t,
'gemm_nn': isc.mproduct_nn, 'gemm_tn': isc.mproduct_tn, 'gemm_nt': isc.mproduct_nt, 'gemm_tt':isc.mproduct_tt}[args.operation] 'gemm_nn': isc.templates.gemm_nn, 'gemm_tn': isc.templates.gemm_tn, 'gemm_nt': isc.templates.gemm_nt, 'gemm_tt':isc.templates.gemm_tt}[args.operation]
if not args.json: if not args.json:
json = tools.sanitize(device.name) + '.json' json = tools.sanitize(device.name) + '.json'
return (device, operation, json) return (device, operation, json)