Cleaning: Largely renamed templates to BLAS-like names
This commit is contained in:
@@ -219,15 +219,15 @@ array_expression norm(array_expression const &, unsigned int order = 2);
|
|||||||
|
|
||||||
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
|
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
|
||||||
|
|
||||||
#define ISAAC_DECLARE_REDUCTION(OPNAME) \
|
#define ISAAC_DECLARE_DOT(OPNAME) \
|
||||||
array_expression OPNAME(array const & M, int_t axis = -1);\
|
array_expression OPNAME(array const & M, int_t axis = -1);\
|
||||||
array_expression OPNAME(array_expression const & M, int_t axis = -1);
|
array_expression OPNAME(array_expression const & M, int_t axis = -1);
|
||||||
|
|
||||||
ISAAC_DECLARE_REDUCTION(sum)
|
ISAAC_DECLARE_DOT(sum)
|
||||||
ISAAC_DECLARE_REDUCTION(argmax)
|
ISAAC_DECLARE_DOT(argmax)
|
||||||
ISAAC_DECLARE_REDUCTION(max)
|
ISAAC_DECLARE_DOT(max)
|
||||||
ISAAC_DECLARE_REDUCTION(min)
|
ISAAC_DECLARE_DOT(min)
|
||||||
ISAAC_DECLARE_REDUCTION(argmin)
|
ISAAC_DECLARE_DOT(argmin)
|
||||||
|
|
||||||
array_expression eye(std::size_t, std::size_t, isaac::numeric_type, driver::Context context = driver::queues.default_context());
|
array_expression eye(std::size_t, std::size_t, isaac::numeric_type, driver::Context context = driver::queues.default_context());
|
||||||
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, driver::Context context = driver::queues.default_context());
|
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, driver::Context context = driver::queues.default_context());
|
||||||
|
@@ -84,46 +84,46 @@ protected:
|
|||||||
*
|
*
|
||||||
* Maps prod(matrix_expression, matrix_expression)
|
* Maps prod(matrix_expression, matrix_expression)
|
||||||
*/
|
*/
|
||||||
class mapped_mproduct : public mapped_object, public binary_leaf
|
class mapped_gemm : public mapped_object, public binary_leaf
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info);
|
mapped_gemm(std::string const & scalartype, unsigned int id, node_info info);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Reduction
|
/** @brief Reduction
|
||||||
*
|
*
|
||||||
* Base class for mapping a reduction
|
* Base class for mapping a dot (i.e. a reduction)
|
||||||
*/
|
*/
|
||||||
class mapped_reduction : public mapped_object, public binary_leaf
|
class mapped_dot : public mapped_object, public binary_leaf
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
|
mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
|
||||||
|
|
||||||
int_t root_idx() const;
|
int_t root_idx() const;
|
||||||
isaac::array_expression const & array_expression() const;
|
isaac::array_expression const & array_expression() const;
|
||||||
array_expression::node root_node() const;
|
array_expression::node root_node() const;
|
||||||
bool is_index_reduction() const;
|
bool is_index_dot() const;
|
||||||
op_element root_op() const;
|
op_element root_op() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Scalar reduction
|
/** @brief Scalar dot
|
||||||
*
|
*
|
||||||
* Maps a scalar reduction (max, min, argmax, inner_prod, etc.)
|
* Maps a scalar dot, i.e. a scalar reduction (max, min, argmax, inner_prod, etc.)
|
||||||
*/
|
*/
|
||||||
class mapped_scalar_reduction : public mapped_reduction
|
class mapped_scalar_dot : public mapped_dot
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info);
|
mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Vector reduction
|
/** @brief Vector dot
|
||||||
*
|
*
|
||||||
* Maps a row-wise reduction (max, min, argmax, matrix-vector product, etc.)
|
* Maps a row-wise dot, i.e. a row-wise reduction (max, min, argmax, matrix-vector product, etc.)
|
||||||
*/
|
*/
|
||||||
class mapped_mreduction : public mapped_reduction
|
class mapped_gemv : public mapped_dot
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info);
|
mapped_gemv(std::string const & scalartype, unsigned int id, node_info info);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Host scalar
|
/** @brief Host scalar
|
||||||
|
@@ -13,8 +13,8 @@ namespace detail
|
|||||||
{
|
{
|
||||||
|
|
||||||
bool is_node_leaf(op_element const & op);
|
bool is_node_leaf(op_element const & op);
|
||||||
bool is_scalar_reduction(array_expression::node const & node);
|
bool is_scalar_dot(array_expression::node const & node);
|
||||||
bool is_vector_reduction(array_expression::node const & node);
|
bool is_vector_dot(array_expression::node const & node);
|
||||||
bool is_assignment(op_element const & op);
|
bool is_assignment(op_element const & op);
|
||||||
bool is_elementwise_operator(op_element const & op);
|
bool is_elementwise_operator(op_element const & op);
|
||||||
bool is_elementwise_function(op_element const & op);
|
bool is_elementwise_function(op_element const & op);
|
||||||
|
@@ -5,27 +5,30 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
class vaxpy_parameters : public base::parameters_type
|
class axpy_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
vaxpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy);
|
axpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy);
|
||||||
unsigned int num_groups;
|
unsigned int num_groups;
|
||||||
fetching_policy_type fetching_policy;
|
fetching_policy_type fetching_policy;
|
||||||
};
|
};
|
||||||
|
|
||||||
class vaxpy : public base_impl<vaxpy, vaxpy_parameters>
|
class axpy : public base_impl<axpy, axpy_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||||
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
||||||
public:
|
public:
|
||||||
vaxpy(vaxpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
axpy(axpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
vaxpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
axpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
@@ -14,6 +14,9 @@
|
|||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
enum fetching_policy_type
|
enum fetching_policy_type
|
||||||
{
|
{
|
||||||
FETCH_FROM_LOCAL,
|
FETCH_FROM_LOCAL,
|
||||||
@@ -147,8 +150,8 @@ protected:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op);
|
static void compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op);
|
||||||
static void compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op);
|
static void compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op);
|
||||||
static void process_all(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings);
|
static void process_all(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings);
|
||||||
static void process_all_at(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings, size_t root_idx, leaf_t leaf);
|
static void process_all_at(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings, size_t root_idx, leaf_t leaf);
|
||||||
static std::string neutral_element(op_element const & op, driver::backend_type backend, std::string const & datatype);
|
static std::string neutral_element(op_element const & op, driver::backend_type backend, std::string const & datatype);
|
||||||
@@ -159,8 +162,8 @@ protected:
|
|||||||
static bool is_strided(array_expression::node const & node);
|
static bool is_strided(array_expression::node const & node);
|
||||||
static int_t vector_size(array_expression::node const & node);
|
static int_t vector_size(array_expression::node const & node);
|
||||||
static std::pair<int_t, int_t> matrix_size(array_expression::node const & node);
|
static std::pair<int_t, int_t> matrix_size(array_expression::node const & node);
|
||||||
static bool is_reduction(array_expression::node const & node);
|
static bool is_dot(array_expression::node const & node);
|
||||||
static bool is_index_reduction(op_element const & op);
|
static bool is_index_dot(op_element const & op);
|
||||||
static std::string access_vector_type(std::string const & v, int i);
|
static std::string access_vector_type(std::string const & v, int i);
|
||||||
|
|
||||||
tools::shared_ptr<symbolic_binder> make_binder();
|
tools::shared_ptr<symbolic_binder> make_binder();
|
||||||
@@ -204,6 +207,7 @@ protected:
|
|||||||
binding_policy_t binding_policy_;
|
binding_policy_t binding_policy_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -1,32 +1,34 @@
|
|||||||
#ifndef ISAAC_BACKEND_TEMPLATES_REDUCTION_H
|
#ifndef ISAAC_BACKEND_TEMPLATES_DOT_H
|
||||||
#define ISAAC_BACKEND_TEMPLATES_REDUCTION_H
|
#define ISAAC_BACKEND_TEMPLATES_DOT_H
|
||||||
|
|
||||||
#include "isaac/backend/templates/base.h"
|
#include "isaac/backend/templates/base.h"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
struct reduction_parameters : public base::parameters_type
|
|
||||||
{
|
{
|
||||||
reduction_parameters(unsigned int _simd_width,
|
|
||||||
|
struct dot_parameters : public base::parameters_type
|
||||||
|
{
|
||||||
|
dot_parameters(unsigned int _simd_width,
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
unsigned int _group_size, unsigned int _num_groups,
|
||||||
fetching_policy_type _fetching_policy);
|
fetching_policy_type _fetching_policy);
|
||||||
unsigned int num_groups;
|
unsigned int num_groups;
|
||||||
fetching_policy_type fetching_policy;
|
fetching_policy_type fetching_policy;
|
||||||
};
|
};
|
||||||
|
|
||||||
class reduction : public base_impl<reduction, reduction_parameters>
|
class dot : public base_impl<dot, dot_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
unsigned int lmem_usage(expressions_tuple const & expressions) const;
|
unsigned int lmem_usage(expressions_tuple const & expressions) const;
|
||||||
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||||
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
|
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
|
||||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
|
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
|
||||||
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
reduction(reduction::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
dot(dot::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
reduction(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
dot(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
private:
|
private:
|
||||||
@@ -34,6 +36,7 @@ private:
|
|||||||
std::vector< driver::Buffer > tmpidx_;
|
std::vector< driver::Buffer > tmpidx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
@@ -7,12 +7,13 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
class model;
|
class model;
|
||||||
|
|
||||||
struct mproduct_parameters : public base::parameters_type
|
struct gemm_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
mproduct_parameters(unsigned int simd_width
|
gemm_parameters(unsigned int simd_width
|
||||||
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
|
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
||||||
@@ -38,7 +39,7 @@ struct mproduct_parameters : public base::parameters_type
|
|||||||
bool unroll_outer;
|
bool unroll_outer;
|
||||||
};
|
};
|
||||||
|
|
||||||
class mproduct : public base_impl<mproduct, mproduct_parameters>
|
class gemm : public base_impl<gemm, gemm_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
unsigned int lmem_usage(expressions_tuple const & expressions) const;
|
unsigned int lmem_usage(expressions_tuple const & expressions) const;
|
||||||
@@ -50,7 +51,7 @@ private:
|
|||||||
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
||||||
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
|
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
|
||||||
public:
|
public:
|
||||||
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
gemm(gemm::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
|
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
|
||||||
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
|
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
|
||||||
@@ -62,41 +63,41 @@ private:
|
|||||||
bool check_bounds_;
|
bool check_bounds_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class mproduct_nn : public mproduct
|
class gemm_nn : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mproduct_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
class mproduct_tn : public mproduct
|
class gemm_tn : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mproduct_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class mproduct_nt : public mproduct
|
class gemm_nt : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mproduct_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class mproduct_tt : public mproduct
|
class gemm_tt : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mproduct_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
61
include/isaac/backend/templates/gemv.h
Normal file
61
include/isaac/backend/templates/gemv.h
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#ifndef ISAAC_BACKEND_TEMPLATES_MDOT_H
|
||||||
|
#define ISAAC_BACKEND_TEMPLATES_MDOT_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "isaac/symbolic/expression.h"
|
||||||
|
#include "isaac/backend/templates/base.h"
|
||||||
|
|
||||||
|
namespace isaac
|
||||||
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
struct gemv_parameters : public base::parameters_type
|
||||||
|
{
|
||||||
|
gemv_parameters(unsigned int _simd_width,
|
||||||
|
unsigned int _local_size_0, unsigned int _local_size_1,
|
||||||
|
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
|
||||||
|
unsigned int num_groups_0;
|
||||||
|
unsigned int num_groups_1;
|
||||||
|
fetching_policy_type fetch_policy;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class gemv : public base_impl<gemv, gemv_parameters>
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
enum dot_type
|
||||||
|
{
|
||||||
|
REDUCE_ROWS,
|
||||||
|
REDUCE_COLUMNS
|
||||||
|
};
|
||||||
|
gemv(gemv::parameters_type const & , dot_type, binding_policy_t);
|
||||||
|
private:
|
||||||
|
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||||
|
unsigned int lmem_usage() const;
|
||||||
|
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
|
||||||
|
public:
|
||||||
|
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
|
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
|
private:
|
||||||
|
dot_type dot_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class gemv_n : public gemv
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
gemv_n(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
|
gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
|
};
|
||||||
|
|
||||||
|
class gemv_t : public gemv
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
gemv_t(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
|
gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -6,29 +6,32 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
class maxpy_parameters : public base::parameters_type
|
class ger_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
maxpy_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy);
|
ger_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy);
|
||||||
|
|
||||||
unsigned int num_groups_0;
|
unsigned int num_groups_0;
|
||||||
unsigned int num_groups_1;
|
unsigned int num_groups_1;
|
||||||
fetching_policy_type fetching_policy;
|
fetching_policy_type fetching_policy;
|
||||||
};
|
};
|
||||||
|
|
||||||
class maxpy : public base_impl<maxpy, maxpy_parameters>
|
class ger : public base_impl<ger, ger_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||||
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
|
||||||
public:
|
public:
|
||||||
maxpy(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
ger(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
ger(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
@@ -1,59 +0,0 @@
|
|||||||
#ifndef ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
|
|
||||||
#define ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "isaac/symbolic/expression.h"
|
|
||||||
#include "isaac/backend/templates/base.h"
|
|
||||||
|
|
||||||
namespace isaac
|
|
||||||
{
|
|
||||||
|
|
||||||
struct mreduction_parameters : public base::parameters_type
|
|
||||||
{
|
|
||||||
mreduction_parameters(unsigned int _simd_width,
|
|
||||||
unsigned int _local_size_0, unsigned int _local_size_1,
|
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
|
|
||||||
unsigned int num_groups_0;
|
|
||||||
unsigned int num_groups_1;
|
|
||||||
fetching_policy_type fetch_policy;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class mreduction : public base_impl<mreduction, mreduction_parameters>
|
|
||||||
{
|
|
||||||
protected:
|
|
||||||
enum reduction_type
|
|
||||||
{
|
|
||||||
REDUCE_ROWS,
|
|
||||||
REDUCE_COLUMNS
|
|
||||||
};
|
|
||||||
mreduction(mreduction::parameters_type const & , reduction_type, binding_policy_t);
|
|
||||||
private:
|
|
||||||
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
|
||||||
unsigned int lmem_usage() const;
|
|
||||||
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
|
|
||||||
public:
|
|
||||||
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
|
||||||
private:
|
|
||||||
reduction_type reduction_type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class mreduction_rows : public mreduction
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
mreduction_rows(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
|
||||||
mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
|
||||||
};
|
|
||||||
|
|
||||||
class mreduction_cols : public mreduction
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
mreduction_cols(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
|
||||||
mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@@ -14,7 +14,7 @@ namespace isaac
|
|||||||
|
|
||||||
class model
|
class model
|
||||||
{
|
{
|
||||||
typedef tools::shared_ptr<base> template_pointer;
|
typedef tools::shared_ptr<templates::base> template_pointer;
|
||||||
typedef std::vector< template_pointer > templates_container;
|
typedef std::vector< template_pointer > templates_container;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -23,8 +23,8 @@ namespace isaac
|
|||||||
driver::Program& init(controller<expressions_tuple> const &);
|
driver::Program& init(controller<expressions_tuple> const &);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<base> > const &, driver::CommandQueue const &);
|
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
||||||
model(expression_type, numeric_type, base const &, driver::CommandQueue const &);
|
model(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
|
||||||
|
|
||||||
void execute(controller<expressions_tuple> const &);
|
void execute(controller<expressions_tuple> const &);
|
||||||
templates_container const & templates() const;
|
templates_container const & templates() const;
|
||||||
@@ -46,7 +46,7 @@ namespace isaac
|
|||||||
model_map_t init_models(driver::CommandQueue const & queue);
|
model_map_t init_models(driver::CommandQueue const & queue);
|
||||||
model_map_t& models(driver::CommandQueue & queue);
|
model_map_t& models(driver::CommandQueue & queue);
|
||||||
|
|
||||||
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks;
|
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks;
|
||||||
extern std::map<driver::CommandQueue, model_map_t> models_;
|
extern std::map<driver::CommandQueue, model_map_t> models_;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -31,14 +31,14 @@ enum operation_node_type_family
|
|||||||
// BLAS1-type
|
// BLAS1-type
|
||||||
OPERATOR_UNARY_TYPE_FAMILY,
|
OPERATOR_UNARY_TYPE_FAMILY,
|
||||||
OPERATOR_BINARY_TYPE_FAMILY,
|
OPERATOR_BINARY_TYPE_FAMILY,
|
||||||
OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY,
|
OPERATOR_VECTOR_DOT_TYPE_FAMILY,
|
||||||
|
|
||||||
// BLAS2-type
|
// BLAS2-type
|
||||||
OPERATOR_ROWS_REDUCTION_TYPE_FAMILY,
|
OPERATOR_ROWS_DOT_TYPE_FAMILY,
|
||||||
OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY,
|
OPERATOR_COLUMNS_DOT_TYPE_FAMILY,
|
||||||
|
|
||||||
// BLAS3-type
|
// BLAS3-type
|
||||||
OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
OPERATOR_GEMM_TYPE_FAMILY
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Enumeration for identifying the possible operations */
|
/** @brief Enumeration for identifying the possible operations */
|
||||||
@@ -119,10 +119,10 @@ enum operation_node_type
|
|||||||
OPERATOR_SHIFT_TYPE,
|
OPERATOR_SHIFT_TYPE,
|
||||||
OPERATOR_VDIAG_TYPE,
|
OPERATOR_VDIAG_TYPE,
|
||||||
|
|
||||||
OPERATOR_MATRIX_PRODUCT_NN_TYPE,
|
OPERATOR_GEMM_NN_TYPE,
|
||||||
OPERATOR_MATRIX_PRODUCT_TN_TYPE,
|
OPERATOR_GEMM_TN_TYPE,
|
||||||
OPERATOR_MATRIX_PRODUCT_NT_TYPE,
|
OPERATOR_GEMM_NT_TYPE,
|
||||||
OPERATOR_MATRIX_PRODUCT_TT_TYPE,
|
OPERATOR_GEMM_TT_TYPE,
|
||||||
|
|
||||||
OPERATOR_PAIR_TYPE
|
OPERATOR_PAIR_TYPE
|
||||||
};
|
};
|
||||||
|
@@ -128,16 +128,15 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
|
|||||||
enum expression_type
|
enum expression_type
|
||||||
{
|
{
|
||||||
INVALID_EXPRESSION_TYPE,
|
INVALID_EXPRESSION_TYPE,
|
||||||
SCALAR_AXPY_TYPE,
|
AXPY_TYPE,
|
||||||
VECTOR_AXPY_TYPE,
|
GER_TYPE,
|
||||||
MATRIX_AXPY_TYPE,
|
DOT_TYPE,
|
||||||
REDUCTION_TYPE,
|
GEMV_N_TYPE,
|
||||||
ROW_WISE_REDUCTION_TYPE,
|
GEMV_T_TYPE,
|
||||||
COL_WISE_REDUCTION_TYPE,
|
GEMM_NN_TYPE,
|
||||||
MATRIX_PRODUCT_NN_TYPE,
|
GEMM_TN_TYPE,
|
||||||
MATRIX_PRODUCT_TN_TYPE,
|
GEMM_NT_TYPE,
|
||||||
MATRIX_PRODUCT_NT_TYPE,
|
GEMM_TT_TYPE
|
||||||
MATRIX_PRODUCT_TT_TYPE
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct slice
|
struct slice
|
||||||
|
@@ -596,17 +596,17 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
|
|||||||
|
|
||||||
///*--- Reductions ---*/
|
///*--- Reductions ---*/
|
||||||
////---------------------------------------
|
////---------------------------------------
|
||||||
#define DEFINE_REDUCTION(OP, OPNAME)\
|
#define DEFINE_DOT(OP, OPNAME)\
|
||||||
array_expression OPNAME(array const & x, int_t axis)\
|
array_expression OPNAME(array const & x, int_t axis)\
|
||||||
{\
|
{\
|
||||||
if(axis < -1 || axis > x.nshape())\
|
if(axis < -1 || axis > x.nshape())\
|
||||||
throw std::out_of_range("The axis entry is out of bounds");\
|
throw std::out_of_range("The axis entry is out of bounds");\
|
||||||
else if(axis==-1)\
|
else if(axis==-1)\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||||
else if(axis==0)\
|
else if(axis==0)\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||||
else\
|
else\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||||
}\
|
}\
|
||||||
\
|
\
|
||||||
array_expression OPNAME(array_expression const & x, int_t axis)\
|
array_expression OPNAME(array_expression const & x, int_t axis)\
|
||||||
@@ -614,20 +614,20 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
|
|||||||
if(axis < -1 || axis > x.nshape())\
|
if(axis < -1 || axis > x.nshape())\
|
||||||
throw std::out_of_range("The axis entry is out of bounds");\
|
throw std::out_of_range("The axis entry is out of bounds");\
|
||||||
if(axis==-1)\
|
if(axis==-1)\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
|
||||||
else if(axis==0)\
|
else if(axis==0)\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
|
||||||
else\
|
else\
|
||||||
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
|
DEFINE_DOT(OPERATOR_ADD_TYPE, sum)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
|
DEFINE_DOT(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max)
|
DEFINE_DOT(OPERATOR_ELEMENT_MAX_TYPE, max)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min)
|
DEFINE_DOT(OPERATOR_ELEMENT_MIN_TYPE, min)
|
||||||
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
|
DEFINE_DOT(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
|
||||||
|
|
||||||
#undef DEFINE_REDUCTION
|
#undef DEFINE_DOT
|
||||||
|
|
||||||
namespace detail
|
namespace detail
|
||||||
{
|
{
|
||||||
@@ -635,21 +635,21 @@ namespace detail
|
|||||||
array_expression matmatprod(array const & A, array const & B)
|
array_expression matmatprod(array const & A, array const & B)
|
||||||
{
|
{
|
||||||
size4 shape(A.shape()[0], B.shape()[1]);
|
size4 shape(A.shape()[0], B.shape()[1]);
|
||||||
return array_expression(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, OPERATOR_MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
|
return array_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
array_expression matmatprod(array_expression const & A, array const & B)
|
array_expression matmatprod(array_expression const & A, array const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
|
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||||
size4 shape(A.shape()[0], B.shape()[1]);
|
size4 shape(A.shape()[0], B.shape()[1]);
|
||||||
|
|
||||||
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
if(A_trans){
|
if(A_trans){
|
||||||
type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
|
type = OPERATOR_GEMM_TN_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||||
if(A_trans) res_root.lhs = A_root.lhs;
|
if(A_trans) res_root.lhs = A_root.lhs;
|
||||||
return res;
|
return res;
|
||||||
@@ -657,16 +657,16 @@ namespace detail
|
|||||||
|
|
||||||
array_expression matmatprod(array const & A, array_expression const & B)
|
array_expression matmatprod(array const & A, array_expression const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
|
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||||
size4 shape(A.shape()[0], B.shape()[1]);
|
size4 shape(A.shape()[0], B.shape()[1]);
|
||||||
|
|
||||||
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
||||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
if(B_trans){
|
if(B_trans){
|
||||||
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
type = OPERATOR_GEMM_NT_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||||
if(B_trans) res_root.rhs = B_root.lhs;
|
if(B_trans) res_root.rhs = B_root.lhs;
|
||||||
return res;
|
return res;
|
||||||
@@ -674,7 +674,7 @@ namespace detail
|
|||||||
|
|
||||||
array_expression matmatprod(array_expression const & A, array_expression const & B)
|
array_expression matmatprod(array_expression const & A, array_expression const & B)
|
||||||
{
|
{
|
||||||
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
|
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
|
||||||
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
|
||||||
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
|
||||||
size4 shape(A.shape()[0], B.shape()[1]);
|
size4 shape(A.shape()[0], B.shape()[1]);
|
||||||
@@ -682,12 +682,12 @@ namespace detail
|
|||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
|
|
||||||
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE;
|
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
|
||||||
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
|
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
|
||||||
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
|
||||||
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
|
else type = OPERATOR_GEMM_NN_TYPE;
|
||||||
|
|
||||||
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
|
||||||
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
|
||||||
if(A_trans) res_root.lhs = A_root.lhs;
|
if(A_trans) res_root.lhs = A_root.lhs;
|
||||||
if(B_trans) res_root.rhs = B_root.lhs;
|
if(B_trans) res_root.rhs = B_root.lhs;
|
||||||
|
@@ -102,23 +102,23 @@ std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
mapped_mproduct::mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "mproduct"), binary_leaf(info) { }
|
mapped_gemm::mapped_gemm(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "gemm"), binary_leaf(info) { }
|
||||||
|
|
||||||
//
|
//
|
||||||
mapped_reduction::mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
|
mapped_dot::mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
|
||||||
mapped_object(scalartype, id, type_key), binary_leaf(info)
|
mapped_object(scalartype, id, type_key), binary_leaf(info)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
int_t mapped_reduction::root_idx() const
|
int_t mapped_dot::root_idx() const
|
||||||
{ return info_.root_idx; }
|
{ return info_.root_idx; }
|
||||||
|
|
||||||
isaac::array_expression const & mapped_reduction::array_expression() const
|
isaac::array_expression const & mapped_dot::array_expression() const
|
||||||
{ return *info_.array_expression; }
|
{ return *info_.array_expression; }
|
||||||
|
|
||||||
array_expression::node mapped_reduction::root_node() const
|
array_expression::node mapped_dot::root_node() const
|
||||||
{ return array_expression().tree()[root_idx()]; }
|
{ return array_expression().tree()[root_idx()]; }
|
||||||
|
|
||||||
bool mapped_reduction::is_index_reduction() const
|
bool mapped_dot::is_index_dot() const
|
||||||
{
|
{
|
||||||
op_element const & op = root_op();
|
op_element const & op = root_op();
|
||||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
||||||
@@ -127,17 +127,17 @@ bool mapped_reduction::is_index_reduction() const
|
|||||||
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
|
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
op_element mapped_reduction::root_op() const
|
op_element mapped_dot::root_op() const
|
||||||
{
|
{
|
||||||
return info_.array_expression->tree()[info_.root_idx].op;
|
return info_.array_expression->tree()[info_.root_idx].op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
mapped_scalar_reduction::mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ }
|
mapped_scalar_dot::mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "scalar_dot"){ }
|
||||||
|
|
||||||
//
|
//
|
||||||
mapped_mreduction::mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "mreduction") { }
|
mapped_gemv::mapped_gemv(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "gemv") { }
|
||||||
|
|
||||||
//
|
//
|
||||||
void mapped_host_scalar::preprocess(std::string & str) const
|
void mapped_host_scalar::preprocess(std::string & str) const
|
||||||
|
@@ -10,15 +10,15 @@ namespace detail
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool is_scalar_reduction(array_expression::node const & node)
|
bool is_scalar_dot(array_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY;
|
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_vector_reduction(array_expression::node const & node)
|
bool is_vector_dot(array_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY;
|
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_assignment(op_element const & op)
|
bool is_assignment(op_element const & op)
|
||||||
@@ -75,10 +75,10 @@ namespace detail
|
|||||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||||
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
||||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
||||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,10 +214,10 @@ const char * evaluate(operation_node_type type)
|
|||||||
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
||||||
|
|
||||||
//Binary
|
//Binary
|
||||||
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN";
|
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
|
||||||
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN";
|
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
|
||||||
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT";
|
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
|
||||||
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT";
|
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
|
||||||
case OPERATOR_VDIAG_TYPE : return "vdiag";
|
case OPERATOR_VDIAG_TYPE : return "vdiag";
|
||||||
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
|
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
|
||||||
case OPERATOR_MATRIX_ROW_TYPE : return "row";
|
case OPERATOR_MATRIX_ROW_TYPE : return "row";
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
#include "isaac/backend/templates/vaxpy.h"
|
#include "isaac/backend/templates/axpy.h"
|
||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
#include "isaac/driver/backend.h"
|
#include "isaac/driver/backend.h"
|
||||||
#include "isaac/tools/make_map.hpp"
|
#include "isaac/tools/make_map.hpp"
|
||||||
@@ -8,23 +8,24 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
|
axpy_parameters::axpy_parameters(unsigned int _simd_width,
|
||||||
vaxpy_parameters::vaxpy_parameters(unsigned int _simd_width,
|
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
unsigned int _group_size, unsigned int _num_groups,
|
||||||
fetching_policy_type _fetching_policy) :
|
fetching_policy_type _fetching_policy) :
|
||||||
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
|
||||||
int vaxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
int axpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||||
{
|
{
|
||||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
std::string axpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||||
{
|
{
|
||||||
driver::backend_type backend = device.backend();
|
driver::backend_type backend = device.backend();
|
||||||
std::string _size_t = size_type(device);
|
std::string _size_t = size_type(device);
|
||||||
@@ -90,25 +91,24 @@ std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const &
|
|||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
vaxpy::vaxpy(vaxpy_parameters const & parameters,
|
axpy::axpy(axpy_parameters const & parameters,
|
||||||
binding_policy_t binding_policy) :
|
binding_policy_t binding_policy) :
|
||||||
base_impl<vaxpy, vaxpy_parameters>(parameters, binding_policy)
|
base_impl<axpy, axpy_parameters>(parameters, binding_policy)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
vaxpy::vaxpy(unsigned int simd, unsigned int ls, unsigned int ng,
|
axpy::axpy(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||||
fetching_policy_type fetch, binding_policy_t bind):
|
fetching_policy_type fetch, binding_policy_t bind):
|
||||||
base_impl<vaxpy, vaxpy_parameters>(vaxpy_parameters(simd,ls,ng,fetch), bind)
|
base_impl<axpy, axpy_parameters>(axpy_parameters(simd,ls,ng,fetch), bind)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
|
||||||
std::vector<int_t> vaxpy::input_sizes(expressions_tuple const & expressions) const
|
std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
|
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
|
||||||
int_t size = static_cast<array_expression const *>(expressions.data().front().get())->shape()[0];
|
|
||||||
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
|
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
//Size
|
//Size
|
||||||
@@ -135,3 +135,4 @@ void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
|
|||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
#include "isaac/array.h"
|
#include "isaac/array.h"
|
||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
#include "isaac/backend/templates/vaxpy.h"
|
#include "isaac/backend/templates/axpy.h"
|
||||||
#include "isaac/backend/templates/reduction.h"
|
#include "isaac/backend/templates/dot.h"
|
||||||
#include "isaac/backend/templates/maxpy.h"
|
#include "isaac/backend/templates/ger.h"
|
||||||
#include "isaac/backend/templates/mreduction.h"
|
#include "isaac/backend/templates/gemv.h"
|
||||||
#include "isaac/backend/templates/mproduct.h"
|
#include "isaac/backend/templates/gemm.h"
|
||||||
#include "isaac/backend/templates/base.h"
|
#include "isaac/backend/templates/base.h"
|
||||||
#include "isaac/backend/parse.h"
|
#include "isaac/backend/parse.h"
|
||||||
#include "isaac/exception/operation_not_supported.h"
|
#include "isaac/exception/operation_not_supported.h"
|
||||||
@@ -17,6 +17,8 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
|
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
|
||||||
{ }
|
{ }
|
||||||
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
|
|||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (detail::is_scalar_reduction(root_node))
|
else if (detail::is_scalar_dot(root_node))
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (detail::is_vector_reduction(root_node))
|
else if (detail::is_vector_dot(root_node))
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
||||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
|
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
|
||||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||||
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
|
|||||||
set_arguments(root_node.rhs);
|
set_arguments(root_node.rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
||||||
{
|
{
|
||||||
if (detail::is_elementwise_function(op))
|
if (detail::is_elementwise_function(op))
|
||||||
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
|
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
|
||||||
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
|
|||||||
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
|
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
||||||
{
|
{
|
||||||
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
||||||
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
||||||
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
|
|||||||
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
|
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
|
||||||
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
|
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
|
||||||
|
|
||||||
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known");
|
default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
|||||||
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
|
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool base::is_reduction(array_expression::node const & node)
|
bool base::is_dot(array_expression::node const & node)
|
||||||
{
|
{
|
||||||
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
||||||
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY;
|
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool base::is_index_reduction(op_element const & op)
|
bool base::is_index_dot(op_element const & op)
|
||||||
{
|
{
|
||||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
||||||
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
||||||
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
|
|||||||
return is_invalid_impl(device, expressions);
|
return is_invalid_impl(device, expressions);
|
||||||
}
|
}
|
||||||
|
|
||||||
template class base_impl<vaxpy, vaxpy_parameters>;
|
template class base_impl<axpy, axpy_parameters>;
|
||||||
template class base_impl<reduction, reduction_parameters>;
|
template class base_impl<dot, dot_parameters>;
|
||||||
template class base_impl<maxpy, maxpy_parameters>;
|
template class base_impl<ger, ger_parameters>;
|
||||||
template class base_impl<mreduction, mreduction_parameters>;
|
template class base_impl<gemv, gemv_parameters>;
|
||||||
template class base_impl<mproduct, mproduct_parameters>;
|
template class base_impl<gemm, gemm_parameters>;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "isaac/backend/templates/reduction.h"
|
#include "isaac/backend/templates/dot.h"
|
||||||
#include <CL/cl.hpp>
|
#include <CL/cl.hpp>
|
||||||
#include "isaac/tools/to_string.hpp"
|
#include "isaac/tools/to_string.hpp"
|
||||||
#include "isaac/tools/make_map.hpp"
|
#include "isaac/tools/make_map.hpp"
|
||||||
@@ -7,13 +7,14 @@
|
|||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
reduction_parameters::reduction_parameters(unsigned int _simd_width,
|
{
|
||||||
|
dot_parameters::dot_parameters(unsigned int _simd_width,
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
unsigned int _group_size, unsigned int _num_groups,
|
||||||
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
|
unsigned int dot::lmem_usage(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
unsigned int res = 0;
|
unsigned int res = 0;
|
||||||
for(const auto & elem : expressions.data())
|
for(const auto & elem : expressions.data())
|
||||||
@@ -24,14 +25,14 @@ unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
int reduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
int dot::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||||
{
|
{
|
||||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
|
inline void dot::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
|
||||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
|
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
|
||||||
{
|
{
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
@@ -44,26 +45,26 @@ inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream,
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (auto & expr : exprs)
|
for (auto & expr : exprs)
|
||||||
if (expr->is_index_reduction())
|
if (expr->is_index_dot())
|
||||||
compute_index_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
|
compute_index_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
|
||||||
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
|
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
|
||||||
expr->root_op());
|
expr->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
|
compute_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string reduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
std::string dot::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||||
{
|
{
|
||||||
kernel_generation_stream stream;
|
kernel_generation_stream stream;
|
||||||
|
|
||||||
std::vector<mapped_scalar_reduction*> exprs;
|
std::vector<mapped_scalar_dot*> exprs;
|
||||||
for (const auto & mapping : mappings)
|
for (const auto & mapping : mappings)
|
||||||
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
|
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
|
||||||
if (mapped_scalar_reduction * p = dynamic_cast<mapped_scalar_reduction*>(iit->second.get()))
|
if (mapped_scalar_dot * p = dynamic_cast<mapped_scalar_dot*>(iit->second.get()))
|
||||||
exprs.push_back(p);
|
exprs.push_back(p);
|
||||||
std::size_t N = exprs.size();
|
std::size_t N = exprs.size();
|
||||||
driver::backend_type backend = device.backend();
|
driver::backend_type backend = device.backend();
|
||||||
@@ -73,7 +74,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
for (unsigned int k = 0; k < N; ++k)
|
for (unsigned int k = 0; k < N; ++k)
|
||||||
{
|
{
|
||||||
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
|
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
|
||||||
if (exprs[k]->is_index_reduction())
|
if (exprs[k]->is_index_dot())
|
||||||
{
|
{
|
||||||
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||||
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
|
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
|
||||||
@@ -112,7 +113,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
|
|
||||||
for (unsigned int k = 0; k < N; ++k)
|
for (unsigned int k = 0; k < N; ++k)
|
||||||
{
|
{
|
||||||
if (exprs[k]->is_index_reduction())
|
if (exprs[k]->is_index_dot())
|
||||||
{
|
{
|
||||||
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
|
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
|
||||||
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
|
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||||
@@ -156,11 +157,11 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
accessors["matrix_diag"] = str[a];
|
accessors["matrix_diag"] = str[a];
|
||||||
accessors["array0"] = "#namereg";
|
accessors["array0"] = "#namereg";
|
||||||
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
|
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
|
||||||
if (elem->is_index_reduction())
|
if (elem->is_index_dot())
|
||||||
compute_index_reduction(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
|
compute_index_dot(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
|
||||||
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
|
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
|
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -168,7 +169,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
//Fills local memory
|
//Fills local memory
|
||||||
for (unsigned int k = 0; k < N; ++k)
|
for (unsigned int k = 0; k < N; ++k)
|
||||||
{
|
{
|
||||||
if (exprs[k]->is_index_reduction())
|
if (exprs[k]->is_index_dot())
|
||||||
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
||||||
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||||
}
|
}
|
||||||
@@ -182,7 +183,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
for (unsigned int k = 0; k < N; ++k)
|
for (unsigned int k = 0; k < N; ++k)
|
||||||
{
|
{
|
||||||
if (exprs[k]->is_index_reduction())
|
if (exprs[k]->is_index_dot())
|
||||||
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
|
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
|
||||||
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
|
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
|
||||||
}
|
}
|
||||||
@@ -205,9 +206,9 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
|
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
|
||||||
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
|
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
|
||||||
|
|
||||||
for (mapped_scalar_reduction* e: exprs)
|
for (mapped_scalar_dot* e: exprs)
|
||||||
{
|
{
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
{
|
{
|
||||||
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
|
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
|
||||||
stream << e->process("unsigned int #name_acc = 0;") << std::endl;
|
stream << e->process("unsigned int #name_acc = 0;") << std::endl;
|
||||||
@@ -224,18 +225,18 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
|
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
for (mapped_scalar_reduction* e: exprs)
|
for (mapped_scalar_dot* e: exprs)
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
compute_index_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
|
compute_index_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
|
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
for (unsigned int k = 0; k < N; ++k)
|
for (unsigned int k = 0; k < N; ++k)
|
||||||
{
|
{
|
||||||
if (exprs[k]->is_index_reduction())
|
if (exprs[k]->is_index_dot())
|
||||||
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
||||||
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||||
}
|
}
|
||||||
@@ -248,7 +249,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
std::map<std::string, std::string> accessors;
|
std::map<std::string, std::string> accessors;
|
||||||
accessors["scalar_reduction"] = "#name_buf[0]";
|
accessors["scalar_dot"] = "#name_buf[0]";
|
||||||
accessors["array0"] = "#pointer[#start]";
|
accessors["array0"] = "#pointer[#start]";
|
||||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
@@ -260,23 +261,23 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
|||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
reduction::reduction(reduction::parameters_type const & parameters,
|
dot::dot(dot::parameters_type const & parameters,
|
||||||
binding_policy_t binding) : base_impl<reduction, reduction_parameters>(parameters, binding)
|
binding_policy_t binding) : base_impl<dot, dot_parameters>(parameters, binding)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
reduction::reduction(unsigned int simd, unsigned int ls, unsigned int ng,
|
dot::dot(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||||
fetching_policy_type fetch, binding_policy_t bind):
|
fetching_policy_type fetch, binding_policy_t bind):
|
||||||
base_impl<reduction, reduction_parameters>(reduction_parameters(simd,ls,ng,fetch), bind)
|
base_impl<dot, dot_parameters>(dot_parameters(simd,ls,ng,fetch), bind)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
std::vector<int_t> reduction::input_sizes(expressions_tuple const & expressions) const
|
std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *(expressions.data().front()), false);
|
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *(expressions.data().front()), false);
|
||||||
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), reductions_idx[0]));
|
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), dots_idx[0]));
|
||||||
return tools::make_vector<int_t>() << N;
|
return tools::make_vector<int_t>() << N;
|
||||||
}
|
}
|
||||||
|
|
||||||
void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
|
|
||||||
@@ -290,12 +291,12 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array_expression::node const *> reductions;
|
std::vector<array_expression::node const *> dots;
|
||||||
for (const auto & elem : expressions.data())
|
for (const auto & elem : expressions.data())
|
||||||
{
|
{
|
||||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *elem, false);
|
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *elem, false);
|
||||||
for (auto & reductions_idx_itt : reductions_idx)
|
for (auto & dots_idx_itt : dots_idx)
|
||||||
reductions.push_back(&(elem)->tree()[reductions_idx_itt]);
|
dots.push_back(&(elem)->tree()[dots_idx_itt]);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Kernel
|
//Kernel
|
||||||
@@ -321,9 +322,9 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
|||||||
//Temporary buffers
|
//Temporary buffers
|
||||||
unsigned int i = 0;
|
unsigned int i = 0;
|
||||||
unsigned int j = 0;
|
unsigned int j = 0;
|
||||||
for (std::vector<array_expression::node const *>::const_iterator it = reductions.begin(); it != reductions.end(); ++it)
|
for (std::vector<array_expression::node const *>::const_iterator it = dots.begin(); it != dots.end(); ++it)
|
||||||
{
|
{
|
||||||
if (is_index_reduction((*it)->op))
|
if (is_index_dot((*it)->op))
|
||||||
{
|
{
|
||||||
if (tmpidx_.size() <= j)
|
if (tmpidx_.size() <= j)
|
||||||
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
|
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
|
||||||
@@ -343,3 +344,4 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
@@ -1,5 +1,5 @@
|
|||||||
#include "isaac/array.h"
|
#include "isaac/array.h"
|
||||||
#include "isaac/backend/templates/mproduct.h"
|
#include "isaac/backend/templates/gemm.h"
|
||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
#include "isaac/model/model.h"
|
#include "isaac/model/model.h"
|
||||||
#include "isaac/symbolic/preset.h"
|
#include "isaac/symbolic/preset.h"
|
||||||
@@ -10,8 +10,10 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||||
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
|
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
||||||
@@ -21,7 +23,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
mL(ms*local_size_0), nL(ns*local_size_1){}
|
mL(ms*local_size_0), nL(ns*local_size_1){}
|
||||||
|
|
||||||
|
|
||||||
unsigned int mproduct::lmem_usage(expressions_tuple const & expressions) const
|
unsigned int gemm::lmem_usage(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
isaac::array_expression const & array_expression = (*expressions.data().front());
|
isaac::array_expression const & array_expression = (*expressions.data().front());
|
||||||
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
||||||
@@ -32,7 +34,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return N*size_of(numeric_t);
|
return N*size_of(numeric_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int mproduct::registers_usage(expressions_tuple const & expressions) const
|
unsigned int gemm::registers_usage(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
isaac::array_expression const & array_expression = (*expressions.data().front());
|
isaac::array_expression const & array_expression = (*expressions.data().front());
|
||||||
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
||||||
@@ -41,7 +43,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return N*size_of(numeric_t);
|
return N*size_of(numeric_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mproduct::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
std::vector<int_t> MNK = input_sizes(expressions);
|
std::vector<int_t> MNK = input_sizes(expressions);
|
||||||
int_t M = MNK[0]; int_t N = MNK[1];
|
int_t M = MNK[0]; int_t N = MNK[1];
|
||||||
@@ -95,7 +97,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string mproduct::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
|
std::string gemm::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
|
||||||
{
|
{
|
||||||
using std::string;
|
using std::string;
|
||||||
using tools::to_string;
|
using tools::to_string;
|
||||||
@@ -437,7 +439,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
#undef VST0RE
|
#undef VST0RE
|
||||||
}
|
}
|
||||||
|
|
||||||
void mproduct::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K,
|
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
||||||
array const & A, array const & B, array const & C,
|
array const & A, array const & B, array const & C,
|
||||||
value_scalar const & alpha, value_scalar const & beta,
|
value_scalar const & alpha, value_scalar const & beta,
|
||||||
driver::Program & program, const char * suffix, execution_options_type const & options)
|
driver::Program & program, const char * suffix, execution_options_type const & options)
|
||||||
@@ -516,7 +518,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
|
array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
|
||||||
{
|
{
|
||||||
slice s0(s0_0, s0_1);
|
slice s0(s0_0, s0_1);
|
||||||
slice s1(s1_0, s1_1);
|
slice s1(s1_0, s1_1);
|
||||||
@@ -525,7 +527,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return array(M, s0, s1);
|
return array(M, s0, s1);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
|
std::vector<int_t> gemm::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
|
||||||
{
|
{
|
||||||
isaac::array_expression & array_expression = (*expressions.data().front());
|
isaac::array_expression & array_expression = (*expressions.data().front());
|
||||||
array_expression::container_type & array = array_expression.tree();
|
array_expression::container_type & array = array_expression.tree();
|
||||||
@@ -537,26 +539,26 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return {M, N, K};
|
return {M, N, K};
|
||||||
}
|
}
|
||||||
|
|
||||||
mproduct::mproduct(mproduct_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||||
{
|
{
|
||||||
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN_TYPE;
|
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
|
||||||
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN_TYPE;
|
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
|
||||||
else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT_TYPE;
|
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
|
||||||
else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT_TYPE;
|
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
|
||||||
else throw;
|
else throw;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) const
|
std::vector<int_t> gemm::input_sizes(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
symbolic::preset::gemm::args dummy;
|
symbolic::preset::gemm::args dummy;
|
||||||
return infos(expressions, dummy);
|
return infos(expressions, dummy);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||||
{
|
{
|
||||||
using namespace tools;
|
using namespace tools;
|
||||||
|
|
||||||
mproduct & fallback = (mproduct&)fallback_base;
|
gemm & fallback = (gemm&)fallback_base;
|
||||||
expressions_tuple const & expressions = ctr.x();
|
expressions_tuple const & expressions = ctr.x();
|
||||||
|
|
||||||
|
|
||||||
@@ -579,8 +581,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
int_t ldstrideA = pA->stride()[0];
|
int_t ldstrideA = pA->stride()[0];
|
||||||
int_t ldstrideB = pB->stride()[0];
|
int_t ldstrideB = pB->stride()[0];
|
||||||
int_t ldstrideC = pC->stride()[0];
|
int_t ldstrideC = pC->stride()[0];
|
||||||
int_t ldstartA = pA->start()[0];
|
|
||||||
int_t ldstartB = pB->start()[0];
|
|
||||||
|
|
||||||
numeric_type dtype = args.C->dtype;
|
numeric_type dtype = args.C->dtype;
|
||||||
|
|
||||||
@@ -613,40 +613,41 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
mproduct_nn::mproduct_nn(unsigned int simd
|
gemm_nn::gemm_nn(unsigned int simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
//
|
//
|
||||||
mproduct_tn::mproduct_tn(unsigned int simd
|
gemm_tn::gemm_tn(unsigned int simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
//
|
//
|
||||||
mproduct_nt::mproduct_nt(unsigned int simd
|
gemm_nt::gemm_nt(unsigned int simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
//
|
//
|
||||||
mproduct_tt::mproduct_tt(unsigned int simd
|
gemm_tt::gemm_tt(unsigned int simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@@ -1,48 +1,50 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "isaac/backend/stream.h"
|
#include "isaac/backend/stream.h"
|
||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
#include "isaac/backend/templates/mreduction.h"
|
#include "isaac/backend/templates/gemv.h"
|
||||||
#include "isaac/tools/to_string.hpp"
|
#include "isaac/tools/to_string.hpp"
|
||||||
#include "isaac/tools/make_map.hpp"
|
#include "isaac/tools/make_map.hpp"
|
||||||
#include "isaac/tools/make_vector.hpp"
|
#include "isaac/tools/make_vector.hpp"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
mreduction_parameters::mreduction_parameters(unsigned int _simd_width,
|
gemv_parameters::gemv_parameters(unsigned int _simd_width,
|
||||||
unsigned int _local_size_0, unsigned int _local_size_1,
|
unsigned int _local_size_0, unsigned int _local_size_1,
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
|
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
|
||||||
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
|
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
|
||||||
|
|
||||||
|
|
||||||
int mreduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||||
{
|
{
|
||||||
if(reduction_type_==REDUCE_ROWS && p_.simd_width>1)
|
if(dot_type_==REDUCE_ROWS && p_.simd_width>1)
|
||||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||||
if (p_.fetch_policy==FETCH_FROM_LOCAL)
|
if (p_.fetch_policy==FETCH_FROM_LOCAL)
|
||||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int mreduction::lmem_usage() const
|
unsigned int gemv::lmem_usage() const
|
||||||
{
|
{
|
||||||
return (p_.local_size_0+1)*p_.local_size_1;
|
return (p_.local_size_0+1)*p_.local_size_1;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string mreduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
std::string gemv::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||||
{
|
{
|
||||||
using tools::to_string;
|
using tools::to_string;
|
||||||
|
|
||||||
|
|
||||||
std::vector<mapped_mreduction*> reductions;
|
std::vector<mapped_gemv*> dots;
|
||||||
expressions_tuple::data_type::const_iterator sit;
|
expressions_tuple::data_type::const_iterator sit;
|
||||||
std::vector<mapping_type>::const_iterator mit;
|
std::vector<mapping_type>::const_iterator mit;
|
||||||
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
|
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
|
||||||
{
|
{
|
||||||
array_expression const & first_expression = *expressions.data().front();
|
array_expression const & first_expression = *expressions.data().front();
|
||||||
std::vector<size_t> idx = filter_nodes(&is_reduction, first_expression, false);
|
std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
|
||||||
for (auto & elem : idx)
|
for (auto & elem : idx)
|
||||||
reductions.push_back((mapped_mreduction*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
|
dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel_generation_stream stream;
|
kernel_generation_stream stream;
|
||||||
@@ -54,10 +56,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
strcat(name[1], suffix);
|
strcat(name[1], suffix);
|
||||||
|
|
||||||
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
|
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
{
|
{
|
||||||
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
|
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
{
|
{
|
||||||
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||||
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
|
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
|
||||||
@@ -87,7 +89,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
unsigned int local_size_0_ld = p_.local_size_0;
|
unsigned int local_size_0_ld = p_.local_size_0;
|
||||||
std::string local_size_0_ld_str = to_string(local_size_0_ld);
|
std::string local_size_0_ld_str = to_string(local_size_0_ld);
|
||||||
|
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
||||||
|
|
||||||
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
||||||
@@ -104,7 +106,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||||
|
|
||||||
stream << "if (r < M)" << std::endl;
|
stream << "if (r < M)" << std::endl;
|
||||||
@@ -116,10 +118,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
std::string data_type = append_width("#scalartype",simd_width);
|
std::string data_type = append_width("#scalartype",simd_width);
|
||||||
|
|
||||||
|
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
{
|
{
|
||||||
std::map<std::string, std::string> accessors;
|
std::map<std::string, std::string> accessors;
|
||||||
if(reduction_type_==REDUCE_COLUMNS)
|
if(dot_type_==REDUCE_COLUMNS)
|
||||||
{
|
{
|
||||||
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
|
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
|
||||||
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
|
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
|
||||||
@@ -141,20 +143,20 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
str[a] = access_vector_type("#namereg",a);
|
str[a] = access_vector_type("#namereg",a);
|
||||||
|
|
||||||
|
|
||||||
for (auto & elem : reductions)
|
for (auto & elem : dots)
|
||||||
for (unsigned int a = 0; a < simd_width; ++a)
|
for (unsigned int a = 0; a < simd_width; ++a)
|
||||||
{
|
{
|
||||||
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
|
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
|
||||||
if (elem->is_index_reduction())
|
if (elem->is_index_dot())
|
||||||
compute_index_reduction(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
|
compute_index_dot(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
|
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
for (auto & expr : reductions)
|
for (auto & expr : dots)
|
||||||
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
||||||
|
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
@@ -167,13 +169,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (auto & e : reductions)
|
for (auto & e : dots)
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||||
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||||
, e->root_op());
|
, e->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
@@ -188,15 +190,15 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
if(p_.num_groups_0==1)
|
if(p_.num_groups_0==1)
|
||||||
{
|
{
|
||||||
std::map<std::string, std::string> accessors;
|
std::map<std::string, std::string> accessors;
|
||||||
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||||
accessors["array1"] = "#pointer[r*#stride]";
|
accessors["array1"] = "#pointer[r*#stride]";
|
||||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
for (mapped_reduction const * e : reductions)
|
for (mapped_dot const * e : dots)
|
||||||
{
|
{
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
||||||
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
||||||
}
|
}
|
||||||
@@ -230,7 +232,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
{"array2", "#pointer += #start1 + #start2*#ld; "
|
{"array2", "#pointer += #start1 + #start2*#ld; "
|
||||||
"#ld *= #nldstride; "}}, expressions, mappings);
|
"#ld *= #nldstride; "}}, expressions, mappings);
|
||||||
|
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
||||||
|
|
||||||
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
||||||
@@ -246,7 +248,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (const auto & e : reductions)
|
for (const auto & e : dots)
|
||||||
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||||
|
|
||||||
stream << "if (r < M)" << std::endl;
|
stream << "if (r < M)" << std::endl;
|
||||||
@@ -256,8 +258,8 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
|
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (mapped_reduction* e: reductions)
|
for (mapped_dot* e: dots)
|
||||||
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
|
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
@@ -266,7 +268,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
for (auto & expr : reductions)
|
for (auto & expr : dots)
|
||||||
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
||||||
|
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
@@ -279,13 +281,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
for (auto & e : reductions)
|
for (auto & e : dots)
|
||||||
if (e->is_index_reduction())
|
if (e->is_index_dot())
|
||||||
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||||
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||||
, e->root_op());
|
, e->root_op());
|
||||||
else
|
else
|
||||||
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
@@ -299,7 +301,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
std::map<std::string, std::string> accessors;
|
std::map<std::string, std::string> accessors;
|
||||||
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||||
accessors["array1"] = "#pointer[r*#stride]";
|
accessors["array1"] = "#pointer[r*#stride]";
|
||||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||||
|
|
||||||
@@ -317,38 +319,38 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
|||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
mreduction::mreduction(mreduction::parameters_type const & parameters,
|
gemv::gemv(gemv::parameters_type const & parameters,
|
||||||
mreduction::reduction_type rtype,
|
gemv::dot_type rtype,
|
||||||
binding_policy_t binding_policy) :
|
binding_policy_t binding_policy) :
|
||||||
base_impl<mreduction, mreduction_parameters>(parameters, binding_policy),
|
base_impl<gemv, gemv_parameters>(parameters, binding_policy),
|
||||||
reduction_type_(rtype){ }
|
dot_type_(rtype){ }
|
||||||
|
|
||||||
std::vector<int_t> mreduction::input_sizes(expressions_tuple const & expressions) const
|
std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
array_expression const & first_expression = *expressions.data().front();
|
array_expression const & first_expression = *expressions.data().front();
|
||||||
std::vector<std::size_t> idx = filter_nodes(&is_reduction, first_expression, false);
|
std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
|
||||||
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
|
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
|
||||||
if(reduction_type_==REDUCE_COLUMNS)
|
if(dot_type_==REDUCE_COLUMNS)
|
||||||
std::swap(MN.first,MN.second);
|
std::swap(MN.first,MN.second);
|
||||||
return tools::make_vector<int_t>() << MN.first << MN.second;
|
return tools::make_vector<int_t>() << MN.first << MN.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
driver::Context const & context = expressions.context();
|
driver::Context const & context = expressions.context();
|
||||||
|
|
||||||
std::vector<int_t> MN = input_sizes(expressions);
|
std::vector<int_t> MN = input_sizes(expressions);
|
||||||
std::vector<array_expression::node const *> reductions;
|
std::vector<array_expression::node const *> dots;
|
||||||
for (const auto & e : expressions.data())
|
for (const auto & e : expressions.data())
|
||||||
{
|
{
|
||||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *e, false);
|
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
|
||||||
for (auto & r : reductions_idx)
|
for (auto & r : dots_idx)
|
||||||
reductions.push_back(&(e)->tree()[r]);
|
dots.push_back(&(e)->tree()[r]);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Fallback
|
//Fallback
|
||||||
if(reduction_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
|
if(dot_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
|
||||||
{
|
{
|
||||||
fallback.enqueue(queue, program, "fallback", fallback, controller);
|
fallback.enqueue(queue, program, "fallback", fallback, controller);
|
||||||
return;
|
return;
|
||||||
@@ -381,9 +383,9 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
|
|||||||
//Temporary buffers
|
//Temporary buffers
|
||||||
unsigned int i = 0;
|
unsigned int i = 0;
|
||||||
unsigned int j = 0;
|
unsigned int j = 0;
|
||||||
for (auto const & r : reductions)
|
for (auto const & r : dots)
|
||||||
{
|
{
|
||||||
if (is_index_reduction(r->op))
|
if (is_index_dot(r->op))
|
||||||
{
|
{
|
||||||
if (tmpidx.size() <= j)
|
if (tmpidx.size() <= j)
|
||||||
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
|
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
|
||||||
@@ -405,24 +407,25 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
|
|||||||
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
mreduction_rows::mreduction_rows(mreduction_parameters const & parameters,
|
gemv_n::gemv_n(gemv_parameters const & parameters,
|
||||||
binding_policy_t binding_policy):
|
binding_policy_t binding_policy):
|
||||||
mreduction(parameters, REDUCE_ROWS, binding_policy){}
|
gemv(parameters, REDUCE_ROWS, binding_policy){}
|
||||||
|
|
||||||
mreduction_rows::mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
||||||
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
|
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
|
||||||
mreduction_cols::mreduction_cols(mreduction::parameters_type const & parameters,
|
gemv_t::gemv_t(gemv::parameters_type const & parameters,
|
||||||
binding_policy_t binding_policy):
|
binding_policy_t binding_policy):
|
||||||
mreduction(parameters, REDUCE_COLUMNS, binding_policy){}
|
gemv(parameters, REDUCE_COLUMNS, binding_policy){}
|
||||||
|
|
||||||
mreduction_cols::mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
||||||
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
|
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
@@ -1,4 +1,4 @@
|
|||||||
#include "isaac/backend/templates/maxpy.h"
|
#include "isaac/backend/templates/ger.h"
|
||||||
#include "isaac/tools/make_map.hpp"
|
#include "isaac/tools/make_map.hpp"
|
||||||
#include "isaac/tools/make_vector.hpp"
|
#include "isaac/tools/make_vector.hpp"
|
||||||
#include "isaac/symbolic/io.h"
|
#include "isaac/symbolic/io.h"
|
||||||
@@ -7,15 +7,17 @@
|
|||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
namespace templates
|
||||||
|
{
|
||||||
|
|
||||||
maxpy_parameters::maxpy_parameters(unsigned int _simd_width,
|
ger_parameters::ger_parameters(unsigned int _simd_width,
|
||||||
unsigned int _local_size_0, unsigned int _local_size_1,
|
unsigned int _local_size_0, unsigned int _local_size_1,
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1,
|
unsigned int _num_groups_0, unsigned int _num_groups_1,
|
||||||
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
|
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
int ger::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||||
{
|
{
|
||||||
if (p_.simd_width>1)
|
if (p_.simd_width>1)
|
||||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||||
@@ -24,7 +26,7 @@ int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) co
|
|||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string maxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
std::string ger::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||||
{
|
{
|
||||||
kernel_generation_stream stream;
|
kernel_generation_stream stream;
|
||||||
std::string _size_t = size_type(device);
|
std::string _size_t = size_type(device);
|
||||||
@@ -95,23 +97,23 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
|
|||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
maxpy::maxpy(parameters_type const & parameters, binding_policy_t binding_policy) :
|
ger::ger(parameters_type const & parameters, binding_policy_t binding_policy) :
|
||||||
base_impl<maxpy, maxpy_parameters>(parameters, binding_policy){ }
|
base_impl<ger, ger_parameters>(parameters, binding_policy){ }
|
||||||
|
|
||||||
maxpy::maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
ger::ger(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
|
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
|
||||||
binding_policy_t bind):
|
binding_policy_t bind):
|
||||||
base_impl<maxpy, maxpy_parameters>(maxpy_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
|
base_impl<ger, ger_parameters>(ger_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
std::vector<int_t> maxpy::input_sizes(expressions_tuple const & expressions) const
|
std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
|
||||||
{
|
{
|
||||||
isaac::array_expression const & array_expression = *(expressions.data().front());
|
isaac::array_expression const & array_expression = *(expressions.data().front());
|
||||||
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
|
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
|
||||||
return tools::make_vector<int_t>() << size.first << size.second;
|
return tools::make_vector<int_t>() << size.first << size.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
char name[32] = {"axpy"};
|
char name[32] = {"axpy"};
|
||||||
@@ -129,3 +131,4 @@ void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
@@ -7,11 +7,11 @@
|
|||||||
|
|
||||||
#include "rapidjson/document.h"
|
#include "rapidjson/document.h"
|
||||||
#include "isaac/backend/parse.h"
|
#include "isaac/backend/parse.h"
|
||||||
#include "isaac/backend/templates/vaxpy.h"
|
#include "isaac/backend/templates/axpy.h"
|
||||||
#include "isaac/backend/templates/reduction.h"
|
#include "isaac/backend/templates/dot.h"
|
||||||
#include "isaac/backend/templates/maxpy.h"
|
#include "isaac/backend/templates/ger.h"
|
||||||
#include "isaac/backend/templates/mreduction.h"
|
#include "isaac/backend/templates/gemv.h"
|
||||||
#include "isaac/backend/templates/mproduct.h"
|
#include "isaac/backend/templates/gemm.h"
|
||||||
#include "isaac/exception/unknown_datatype.h"
|
#include "isaac/exception/unknown_datatype.h"
|
||||||
#include "isaac/exception/operation_not_supported.h"
|
#include "isaac/exception/operation_not_supported.h"
|
||||||
#include "isaac/model/model.h"
|
#include "isaac/model/model.h"
|
||||||
@@ -86,12 +86,12 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
|||||||
return *program;
|
return *program;
|
||||||
}
|
}
|
||||||
|
|
||||||
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<base> > const & templates, driver::CommandQueue const & queue) :
|
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
||||||
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
|
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
|
||||||
model::model(expression_type etype, numeric_type dtype, base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue)
|
model::model(expression_type etype, numeric_type dtype, templates::base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
void model::execute(controller<expressions_tuple> const & expr)
|
void model::execute(controller<expressions_tuple> const & expr)
|
||||||
@@ -148,15 +148,15 @@ namespace detail
|
|||||||
{
|
{
|
||||||
static expression_type get_expression_type(std::string const & name)
|
static expression_type get_expression_type(std::string const & name)
|
||||||
{
|
{
|
||||||
if(name=="vaxpy") return VECTOR_AXPY_TYPE;
|
if(name=="axpy") return AXPY_TYPE;
|
||||||
if(name=="dot") return REDUCTION_TYPE;
|
if(name=="dot") return DOT_TYPE;
|
||||||
if(name=="maxpy") return MATRIX_AXPY_TYPE;
|
if(name=="ger") return GER_TYPE;
|
||||||
if(name=="mreduction_rows") return ROW_WISE_REDUCTION_TYPE;
|
if(name=="gemv_n") return GEMV_N_TYPE;
|
||||||
if(name=="mreduction_cols") return COL_WISE_REDUCTION_TYPE;
|
if(name=="gemv_t") return GEMV_T_TYPE;
|
||||||
if(name=="mproduct_nn") return MATRIX_PRODUCT_NN_TYPE;
|
if(name=="gemm_nn") return GEMM_NN_TYPE;
|
||||||
if(name=="mproduct_nt") return MATRIX_PRODUCT_NT_TYPE;
|
if(name=="gemm_nt") return GEMM_NT_TYPE;
|
||||||
if(name=="mproduct_tn") return MATRIX_PRODUCT_TN_TYPE;
|
if(name=="gemm_tn") return GEMM_TN_TYPE;
|
||||||
if(name=="mproduct_tt") return MATRIX_PRODUCT_TT_TYPE;
|
if(name=="gemm_tt") return GEMM_TT_TYPE;
|
||||||
throw std::invalid_argument("Invalid expression: " + name);
|
throw std::invalid_argument("Invalid expression: " + name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,27 +167,27 @@ namespace detail
|
|||||||
throw std::invalid_argument("Invalid datatype: " + name);
|
throw std::invalid_argument("Invalid datatype: " + name);
|
||||||
}
|
}
|
||||||
|
|
||||||
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
|
static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
|
||||||
{
|
{
|
||||||
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
|
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
|
||||||
if(template_name=="vaxpy")
|
if(template_name=="axpy")
|
||||||
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
|
return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||||
else if(template_name=="dot")
|
else if(template_name=="dot")
|
||||||
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
|
return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
|
||||||
else if(template_name=="maxpy")
|
else if(template_name=="ger")
|
||||||
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||||
else if(template_name.find("mreduction_rows")!=std::string::npos)
|
else if(template_name.find("gemv_n")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||||
else if(template_name.find("mreduction_cols")!=std::string::npos)
|
else if(template_name.find("gemv_t")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||||
else if(template_name.find("mproduct_nn")!=std::string::npos)
|
else if(template_name.find("gemm_nn")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||||
else if(template_name.find("mproduct_tn")!=std::string::npos)
|
else if(template_name.find("gemm_tn")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||||
else if(template_name.find("mproduct_nt")!=std::string::npos)
|
else if(template_name.find("gemm_nt")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||||
else if(template_name.find("mproduct_tt")!=std::string::npos)
|
else if(template_name.find("gemm_tt")!=std::string::npos)
|
||||||
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||||
else
|
else
|
||||||
throw std::invalid_argument("Invalid expression: " + template_name);
|
throw std::invalid_argument("Invalid expression: " + template_name);
|
||||||
}
|
}
|
||||||
@@ -207,7 +207,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
|||||||
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
||||||
document.Parse<0>(str.c_str());
|
document.Parse<0>(str.c_str());
|
||||||
//Deserialize
|
//Deserialize
|
||||||
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemv_n", "gemv_t", "mproduct_nn", "mproduct_tn", "mproduct_nt", "mproduct_tt"};
|
std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
|
||||||
std::vector<std::string> dtype = {"float32", "float64"};
|
std::vector<std::string> dtype = {"float32", "float64"};
|
||||||
for(auto & operation : operations)
|
for(auto & operation : operations)
|
||||||
{
|
{
|
||||||
@@ -223,7 +223,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
|||||||
numeric_type dtype = detail::get_dtype(elem);
|
numeric_type dtype = detail::get_dtype(elem);
|
||||||
|
|
||||||
// Get profiles
|
// Get profiles
|
||||||
std::vector<tools::shared_ptr<base> > templates;
|
std::vector<tools::shared_ptr<templates::base> > templates;
|
||||||
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
||||||
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
||||||
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
|
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
|
||||||
@@ -243,23 +243,22 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > init_fallback()
|
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
|
||||||
{
|
{
|
||||||
typedef tools::shared_ptr<base> ptr_t;
|
typedef tools::shared_ptr<templates::base> ptr_t;
|
||||||
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
|
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
|
||||||
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||||
for(auto DTYPE : types)
|
for(auto DTYPE : types)
|
||||||
{
|
{
|
||||||
res[std::make_pair(SCALAR_AXPY_TYPE, DTYPE)] = ptr_t(new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(AXPY_TYPE, DTYPE)] = ptr_t (new templates::axpy(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
|
||||||
res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(DOT_TYPE, DTYPE)] = ptr_t(new templates::dot(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
|
||||||
res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::ger(1,8,8,8,8,templates::FETCH_FROM_GLOBAL_STRIDED));
|
||||||
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::gemv_n(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
|
||||||
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::gemv_t(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
|
||||||
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
|
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
||||||
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
|
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
||||||
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
|
res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::gemm_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
||||||
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
|
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
||||||
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
|
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@@ -269,7 +268,7 @@ model_map_t init_models(driver::CommandQueue & queue)
|
|||||||
{
|
{
|
||||||
model_map_t res;
|
model_map_t res;
|
||||||
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||||
expression_type etypes[] = {SCALAR_AXPY_TYPE, VECTOR_AXPY_TYPE, REDUCTION_TYPE, MATRIX_AXPY_TYPE, ROW_WISE_REDUCTION_TYPE, COL_WISE_REDUCTION_TYPE, MATRIX_PRODUCT_NN_TYPE, MATRIX_PRODUCT_NT_TYPE, MATRIX_PRODUCT_TN_TYPE, MATRIX_PRODUCT_TT_TYPE};
|
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
|
||||||
|
|
||||||
for(numeric_type dtype: dtypes)
|
for(numeric_type dtype: dtypes)
|
||||||
for(expression_type etype: etypes)
|
for(expression_type etype: etypes)
|
||||||
@@ -288,7 +287,7 @@ model_map_t& models(driver::CommandQueue & queue)
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks = init_fallback();
|
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||||
std::map<driver::CommandQueue, model_map_t> models_;
|
std::map<driver::CommandQueue, model_map_t> models_;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -19,13 +19,13 @@ namespace isaac
|
|||||||
|
|
||||||
inline bool is_mmprod(expression_type x)
|
inline bool is_mmprod(expression_type x)
|
||||||
{
|
{
|
||||||
return x==MATRIX_PRODUCT_NN_TYPE || x==MATRIX_PRODUCT_NT_TYPE ||
|
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
|
||||||
x==MATRIX_PRODUCT_TN_TYPE || x==MATRIX_PRODUCT_TT_TYPE;
|
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool is_mvprod(expression_type x)
|
inline bool is_mvprod(expression_type x)
|
||||||
{
|
{
|
||||||
return x==ROW_WISE_REDUCTION_TYPE || x==COL_WISE_REDUCTION_TYPE;
|
return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
|
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
|
||||||
@@ -36,27 +36,27 @@ namespace isaac
|
|||||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| (result |= expression==ROW_WISE_REDUCTION_TYPE && other==COL_WISE_REDUCTION_TYPE)
|
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|
||||||
|| (result |= expression==COL_WISE_REDUCTION_TYPE && other==ROW_WISE_REDUCTION_TYPE);
|
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
|
||||||
break;
|
break;
|
||||||
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
||||||
result |= is_mvprod(expression)
|
result |= is_mvprod(expression)
|
||||||
|| expression==REDUCTION_TYPE;
|
|| expression==DOT_TYPE;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCTION_TYPE;
|
|| expression==DOT_TYPE;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
||||||
result |= is_mmprod(expression)
|
result |= is_mmprod(expression)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCTION_TYPE;
|
|| expression==DOT_TYPE;
|
||||||
break;
|
break;
|
||||||
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
case OPERATOR_GEMM_TYPE_FAMILY:
|
||||||
result |= (is_mmprod(expression) && !is_first)
|
result |= (is_mmprod(expression) && !is_first)
|
||||||
|| is_mvprod(expression)
|
|| is_mvprod(expression)
|
||||||
|| expression==REDUCTION_TYPE;
|
|| expression==DOT_TYPE;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@@ -77,28 +77,28 @@ namespace isaac
|
|||||||
{
|
{
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||||
if(is_mmprod(left))
|
if(is_mmprod(left))
|
||||||
return MATRIX_AXPY_TYPE;
|
return GER_TYPE;
|
||||||
return left;
|
return left;
|
||||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||||
if(left == ROW_WISE_REDUCTION_TYPE || right == ROW_WISE_REDUCTION_TYPE) return ROW_WISE_REDUCTION_TYPE;
|
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
|
||||||
else if(left == COL_WISE_REDUCTION_TYPE || right == COL_WISE_REDUCTION_TYPE) return COL_WISE_REDUCTION_TYPE;
|
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
|
||||||
else if(left == REDUCTION_TYPE || right == REDUCTION_TYPE) return REDUCTION_TYPE;
|
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
|
||||||
else if(left == VECTOR_AXPY_TYPE || right == VECTOR_AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?MATRIX_AXPY_TYPE:VECTOR_AXPY_TYPE;
|
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
|
||||||
else if(left == MATRIX_AXPY_TYPE || right == MATRIX_AXPY_TYPE) return MATRIX_AXPY_TYPE;
|
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
|
||||||
else if(is_mmprod(left) || is_mmprod(right)) return MATRIX_AXPY_TYPE;
|
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
|
||||||
std::cout << left << " " << right << std::endl;
|
std::cout << left << " " << right << std::endl;
|
||||||
throw;
|
throw;
|
||||||
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
||||||
return REDUCTION_TYPE;
|
return DOT_TYPE;
|
||||||
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
||||||
return ROW_WISE_REDUCTION_TYPE;
|
return GEMV_N_TYPE;
|
||||||
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
||||||
return COL_WISE_REDUCTION_TYPE;
|
return GEMV_T_TYPE;
|
||||||
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
case OPERATOR_GEMM_TYPE_FAMILY:
|
||||||
if(op.type==OPERATOR_MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN_TYPE;
|
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
|
||||||
else if(op.type==OPERATOR_MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN_TYPE;
|
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
|
||||||
else if(op.type==OPERATOR_MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT_TYPE;
|
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
|
||||||
else return MATRIX_PRODUCT_TT_TYPE;
|
else return GEMM_TT_TYPE;
|
||||||
default:
|
default:
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
@@ -119,9 +119,9 @@ namespace isaac
|
|||||||
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
if(node.lhs.array->nshape()==1)
|
if(node.lhs.array->nshape()==1)
|
||||||
type_left = VECTOR_AXPY_TYPE;
|
type_left = AXPY_TYPE;
|
||||||
else
|
else
|
||||||
type_left = MATRIX_AXPY_TYPE;
|
type_left = GER_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Right
|
//Right
|
||||||
@@ -131,9 +131,9 @@ namespace isaac
|
|||||||
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
if(node.rhs.array->nshape()==1)
|
if(node.rhs.array->nshape()==1)
|
||||||
type_right = VECTOR_AXPY_TYPE;
|
type_right = AXPY_TYPE;
|
||||||
else
|
else
|
||||||
type_right = MATRIX_AXPY_TYPE;
|
type_right = GER_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -171,12 +171,10 @@ namespace isaac
|
|||||||
|
|
||||||
//Init
|
//Init
|
||||||
expression_type current_type;
|
expression_type current_type;
|
||||||
if(root_save.lhs.array->nshape()==0)
|
if(root_save.lhs.array->nshape()<=1)
|
||||||
current_type = SCALAR_AXPY_TYPE;
|
current_type=AXPY_TYPE;
|
||||||
else if(root_save.lhs.array->nshape()==1)
|
|
||||||
current_type=VECTOR_AXPY_TYPE;
|
|
||||||
else
|
else
|
||||||
current_type=MATRIX_AXPY_TYPE;
|
current_type=GER_TYPE;
|
||||||
final_type = current_type;
|
final_type = current_type;
|
||||||
|
|
||||||
/*----Parse required temporaries-----*/
|
/*----Parse required temporaries-----*/
|
||||||
@@ -193,18 +191,17 @@ namespace isaac
|
|||||||
//Creates temporary
|
//Creates temporary
|
||||||
tools::shared_ptr<array> tmp;
|
tools::shared_ptr<array> tmp;
|
||||||
switch(it->first){
|
switch(it->first){
|
||||||
case SCALAR_AXPY_TYPE:
|
case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
|
||||||
|
|
||||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||||
|
|
||||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||||
|
|
||||||
default: throw std::invalid_argument("Unrecognized operation");
|
default: throw std::invalid_argument("Unrecognized operation");
|
||||||
}
|
}
|
||||||
|
@@ -12,16 +12,16 @@ namespace preset
|
|||||||
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
|
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
|
||||||
{
|
{
|
||||||
//Matrix-Matrix product node
|
//Matrix-Matrix product node
|
||||||
if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
|
||||||
{
|
{
|
||||||
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
|
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
|
||||||
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
|
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
|
||||||
switch(tree[rootidx].op.type)
|
switch(tree[rootidx].op.type)
|
||||||
{
|
{
|
||||||
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break;
|
case OPERATOR_GEMM_NN_TYPE: a.type = GEMM_NN_TYPE; break;
|
||||||
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break;
|
case OPERATOR_GEMM_NT_TYPE: a.type = GEMM_NT_TYPE; break;
|
||||||
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break;
|
case OPERATOR_GEMM_TN_TYPE: a.type = GEMM_TN_TYPE; break;
|
||||||
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break;
|
case OPERATOR_GEMM_TT_TYPE: a.type = GEMM_TT_TYPE; break;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,7 +31,7 @@ void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, a
|
|||||||
{
|
{
|
||||||
//alpha*PROD
|
//alpha*PROD
|
||||||
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
|
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
|
||||||
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
|
||||||
{
|
{
|
||||||
a.alpha = &tree[rootidx].lhs;
|
a.alpha = &tree[rootidx].lhs;
|
||||||
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
||||||
|
@@ -115,7 +115,7 @@ def main():
|
|||||||
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
||||||
|
|
||||||
#Source files
|
#Source files
|
||||||
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||||
boostsrc = 'external/boost/libs/'
|
boostsrc = 'external/boost/libs/'
|
||||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||||
|
@@ -20,6 +20,6 @@ BOOST_PYTHON_MODULE(_isaac)
|
|||||||
|
|
||||||
export_driver();
|
export_driver();
|
||||||
export_exceptions();
|
export_exceptions();
|
||||||
export_model();
|
export_templates();
|
||||||
export_core();
|
export_core();
|
||||||
}
|
}
|
||||||
|
@@ -70,15 +70,15 @@ namespace tools
|
|||||||
else
|
else
|
||||||
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
|
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
|
||||||
|
|
||||||
if(name=="vaxpy") return isc::VECTOR_AXPY_TYPE;
|
if(name=="axpy") return isc::AXPY_TYPE;
|
||||||
else if(name=="maxpy") return isc::MATRIX_AXPY_TYPE;
|
else if(name=="ger") return isc::GER_TYPE;
|
||||||
else if(name=="reduction") return isc::REDUCTION_TYPE;
|
else if(name=="dot") return isc::DOT_TYPE;
|
||||||
else if(name=="mreduction_rows") return isc::ROW_WISE_REDUCTION_TYPE;
|
else if(name=="gemv_n") return isc::GEMV_N_TYPE;
|
||||||
else if(name=="mreduction_cols") return isc::COL_WISE_REDUCTION_TYPE;
|
else if(name=="gemm_t") return isc::GEMV_T_TYPE;
|
||||||
else if(name=="mproduct_nn") return isc::MATRIX_PRODUCT_NN_TYPE;
|
else if(name=="gemm_nn") return isc::GEMM_NN_TYPE;
|
||||||
else if(name=="mproduct_tn") return isc::MATRIX_PRODUCT_TN_TYPE;
|
else if(name=="gemm_tn") return isc::GEMM_TN_TYPE;
|
||||||
else if(name=="mproduct_nt") return isc::MATRIX_PRODUCT_NT_TYPE;
|
else if(name=="gemm_nt") return isc::GEMM_NT_TYPE;
|
||||||
else if(name=="mproduct_tt") return isc::MATRIX_PRODUCT_TT_TYPE;
|
else if(name=="gemm_tt") return isc::GEMM_TT_TYPE;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
PyErr_SetString(PyExc_TypeError, "Template type not understood");
|
PyErr_SetString(PyExc_TypeError, "Template type not understood");
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
#include "isaac/model/model.h"
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
#include "core.h"
|
#include "core.h"
|
||||||
|
|
||||||
@@ -79,6 +80,11 @@ unsigned int size(datatype<T> const & dt)
|
|||||||
|
|
||||||
namespace detail
|
namespace detail
|
||||||
{
|
{
|
||||||
|
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
|
||||||
|
{
|
||||||
|
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isaac::templates::base const &)bp::extract<isaac::templates::base>(tp), queue));
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<isc::array>
|
std::shared_ptr<isc::array>
|
||||||
ndarray_to_iscarray(const np::ndarray& array, const isc::driver::Context& ctx)
|
ndarray_to_iscarray(const np::ndarray& array, const isc::driver::Context& ctx)
|
||||||
{
|
{
|
||||||
@@ -172,7 +178,6 @@ namespace detail
|
|||||||
bp::throw_error_already_set();
|
bp::throw_error_already_set();
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,6 +188,10 @@ namespace detail
|
|||||||
void export_core()
|
void export_core()
|
||||||
{
|
{
|
||||||
|
|
||||||
|
bp::class_<isaac::model>("model", bp::no_init)
|
||||||
|
.def("__init__", bp::make_constructor(detail::construct_model))
|
||||||
|
.def("execute", &isc::model::execute);
|
||||||
|
|
||||||
bp::class_<isc::value_scalar>("value_scalar", bp::no_init)
|
bp::class_<isc::value_scalar>("value_scalar", bp::no_init)
|
||||||
.add_property("dtype", &isc::value_scalar::dtype);
|
.add_property("dtype", &isc::value_scalar::dtype);
|
||||||
|
|
||||||
@@ -206,15 +215,15 @@ void export_core()
|
|||||||
#undef INSTANTIATE
|
#undef INSTANTIATE
|
||||||
|
|
||||||
bp::enum_<isc::expression_type>("operations")
|
bp::enum_<isc::expression_type>("operations")
|
||||||
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
|
MAP_ENUM(AXPY_TYPE, isc)
|
||||||
MAP_ENUM(MATRIX_AXPY_TYPE, isc)
|
MAP_ENUM(GER_TYPE, isc)
|
||||||
MAP_ENUM(REDUCTION_TYPE, isc)
|
MAP_ENUM(DOT_TYPE, isc)
|
||||||
MAP_ENUM(ROW_WISE_REDUCTION_TYPE, isc)
|
MAP_ENUM(GEMV_N_TYPE, isc)
|
||||||
MAP_ENUM(COL_WISE_REDUCTION_TYPE, isc)
|
MAP_ENUM(GEMV_T_TYPE, isc)
|
||||||
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
|
MAP_ENUM(GEMM_NN_TYPE, isc)
|
||||||
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
|
MAP_ENUM(GEMM_TN_TYPE, isc)
|
||||||
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
|
MAP_ENUM(GEMM_NT_TYPE, isc)
|
||||||
MAP_ENUM(VECTOR_AXPY_TYPE, isc);
|
MAP_ENUM(GEMM_TT_TYPE, isc);
|
||||||
|
|
||||||
#define ADD_SCALAR_HANDLING(OP)\
|
#define ADD_SCALAR_HANDLING(OP)\
|
||||||
.def(bp::self OP int())\
|
.def(bp::self OP int())\
|
||||||
|
@@ -1,70 +1,71 @@
|
|||||||
#include "isaac/backend/templates/vaxpy.h"
|
#include "isaac/backend/templates/axpy.h"
|
||||||
#include "isaac/backend/templates/maxpy.h"
|
#include "isaac/backend/templates/ger.h"
|
||||||
#include "isaac/backend/templates/reduction.h"
|
#include "isaac/backend/templates/dot.h"
|
||||||
#include "isaac/backend/templates/mreduction.h"
|
#include "isaac/backend/templates/gemv.h"
|
||||||
#include "isaac/backend/templates/mproduct.h"
|
#include "isaac/backend/templates/gemm.h"
|
||||||
#include "isaac/model/model.h"
|
#include "isaac/model/model.h"
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace tpt = isaac::templates;
|
||||||
|
|
||||||
|
|
||||||
namespace detail
|
namespace detail
|
||||||
{
|
{
|
||||||
bp::list input_sizes(isaac::base & temp, isc::expressions_tuple const & tree)
|
bp::list input_sizes(tpt::base & temp, isc::expressions_tuple const & tree)
|
||||||
{
|
{
|
||||||
std::vector<int> tmp = temp.input_sizes(tree);
|
std::vector<int> tmp = temp.input_sizes(tree);
|
||||||
return tools::to_list(tmp.begin(), tmp.end());
|
return tools::to_list(tmp.begin(), tmp.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
|
|
||||||
{
|
|
||||||
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isc::base const &)bp::extract<isc::base>(tp), queue));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void export_model()
|
void export_templates()
|
||||||
{
|
{
|
||||||
|
|
||||||
bp::class_<isaac::model>("model", bp::no_init)
|
bp::object templates_module(bp::handle<>(bp::borrowed(PyImport_AddModule("isaac.templates"))));
|
||||||
.def("__init__", bp::make_constructor(detail::construct_model))
|
bp::scope().attr("templates") = templates_module;
|
||||||
.def("execute", &isc::model::execute);
|
bp::scope template_scope = templates_module;
|
||||||
|
|
||||||
bp::enum_<isaac::fetching_policy_type>
|
|
||||||
|
bp::enum_<tpt::fetching_policy_type>
|
||||||
("fetching_policy_type")
|
("fetching_policy_type")
|
||||||
.value("FETCH_FROM_LOCAL", isc::FETCH_FROM_LOCAL)
|
.value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL)
|
||||||
.value("FETCH_FROM_GLOBAL_STRIDED", isc::FETCH_FROM_GLOBAL_STRIDED)
|
.value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED)
|
||||||
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", isc::FETCH_FROM_GLOBAL_CONTIGUOUS)
|
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS)
|
||||||
;
|
;
|
||||||
|
|
||||||
//Base
|
//Base
|
||||||
{
|
{
|
||||||
#define __PROP(name) .def_readonly(#name, &isaac::base::parameters_type::name)
|
#define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
|
||||||
bp::class_<isaac::base, boost::noncopyable>("base", bp::no_init)
|
bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
|
||||||
.def("lmem_usage", &isaac::base::lmem_usage)
|
.def("lmem_usage", &tpt::base::lmem_usage)
|
||||||
.def("registers_usage", &isaac::base::registers_usage)
|
.def("registers_usage", &tpt::base::registers_usage)
|
||||||
.def("is_invalid", &isaac::base::is_invalid)
|
.def("is_invalid", &tpt::base::is_invalid)
|
||||||
.def("input_sizes", &detail::input_sizes)
|
.def("input_sizes", &detail::input_sizes)
|
||||||
;
|
;
|
||||||
#undef __PROP
|
#undef __PROP
|
||||||
}
|
}
|
||||||
|
|
||||||
#define WRAP_BASE(name) bp::class_<isaac::base_impl<isaac::name, isaac::name::parameters_type>, bp::bases<isaac::base>, boost::noncopyable>(#name, bp::no_init);
|
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init);
|
||||||
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<isaac::name, bp::bases<isaac::base_impl<isaac::basename, isaac::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
|
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
|
||||||
.add_property("local_size_0", &isc::name::local_size_0)\
|
.add_property("local_size_0", &tpt::name::local_size_0)\
|
||||||
.add_property("local_size_1", &isc::name::local_size_1);
|
.add_property("local_size_1", &tpt::name::local_size_1);
|
||||||
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
|
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
|
||||||
|
|
||||||
//Vector AXPY
|
//Vector AXPY
|
||||||
WRAP_SINGLE_TEMPLATE(vaxpy, uint, uint, uint, isaac::fetching_policy_type)
|
WRAP_SINGLE_TEMPLATE(axpy, uint, uint, uint, tpt::fetching_policy_type)
|
||||||
WRAP_SINGLE_TEMPLATE(maxpy, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
|
WRAP_SINGLE_TEMPLATE(ger, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
|
||||||
WRAP_SINGLE_TEMPLATE(reduction, uint, uint, uint, isaac::fetching_policy_type)
|
WRAP_SINGLE_TEMPLATE(dot, uint, uint, uint, tpt::fetching_policy_type)
|
||||||
WRAP_BASE(mreduction)
|
WRAP_BASE(gemv)
|
||||||
WRAP_TEMPLATE(mreduction_rows, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
|
WRAP_TEMPLATE(gemv_n, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
|
||||||
WRAP_TEMPLATE(mreduction_cols, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
|
WRAP_TEMPLATE(gemv_t, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
|
||||||
WRAP_BASE(mproduct)
|
WRAP_BASE(gemm)
|
||||||
WRAP_TEMPLATE(mproduct_nn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
|
WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
|
||||||
WRAP_TEMPLATE(mproduct_tn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
|
WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
|
||||||
WRAP_TEMPLATE(mproduct_nt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
|
WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
|
||||||
WRAP_TEMPLATE(mproduct_tt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
|
WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
#ifndef ISAAC_PYTHON_MODEL_HPP
|
#ifndef ISAAC_PYTHON_MODEL_HPP
|
||||||
#define ISAAC_PYTHON_MODEL_HPP
|
#define ISAAC_PYTHON_MODEL_HPP
|
||||||
|
|
||||||
void export_model();
|
void export_templates();
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -4,7 +4,7 @@ if(CUDA_FOUND)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
get_property(ISAAC_PATH TARGET isaac PROPERTY LOCATION)
|
get_property(ISAAC_PATH TARGET isaac PROPERTY LOCATION)
|
||||||
foreach(PROG maxpy vaxpy reduction mreduction mproduct)
|
foreach(PROG axpy dot ger gemv gemm)
|
||||||
add_executable(${PROG}-test ${PROG}.cpp)
|
add_executable(${PROG}-test ${PROG}.cpp)
|
||||||
add_test(${PROG} ${PROG}-test)
|
add_test(${PROG} ${PROG}-test)
|
||||||
target_link_libraries(${PROG}-test isaac ${OPENCL_LIBRARIES})
|
target_link_libraries(${PROG}-test isaac ${OPENCL_LIBRARIES})
|
||||||
|
@@ -14,10 +14,10 @@ from numpy import cumsum
|
|||||||
|
|
||||||
import tools
|
import tools
|
||||||
|
|
||||||
fetch_types = [isc.fetching_policy_type.FETCH_FROM_LOCAL,
|
fetch_types = [isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
|
||||||
isc.fetching_policy_type.FETCH_FROM_LOCAL,
|
isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
|
||||||
isc.fetching_policy_type.FETCH_FROM_LOCAL,
|
isc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
|
||||||
isc.fetching_policy_type.FETCH_FROM_LOCAL]
|
isc.templates.fetching_policy_type.FETCH_FROM_LOCAL]
|
||||||
|
|
||||||
def exhaustive(template, sizes, context):
|
def exhaustive(template, sizes, context):
|
||||||
tree, _ = tools.tree_of(template, sizes, context)
|
tree, _ = tools.tree_of(template, sizes, context)
|
||||||
@@ -159,15 +159,15 @@ def is_local_optimum(parameters, template, sizes, context):
|
|||||||
tree, _ = tools.tree_of(template, sizes, context)
|
tree, _ = tools.tree_of(template, sizes, context)
|
||||||
genetic_infos = tools.genetic_infos_of(template)
|
genetic_infos = tools.genetic_infos_of(template)
|
||||||
|
|
||||||
if issubclass(template, isc.vaxpy):
|
if issubclass(template, isc.axpy):
|
||||||
sweep_over = [0,1,2]
|
sweep_over = [0,1,2]
|
||||||
elif issubclass(template, isc.reduction):
|
elif issubclass(template, isc.dot):
|
||||||
sweep_over = [0,1,2]
|
sweep_over = [0,1,2]
|
||||||
elif issubclass(template, isc.maxpy):
|
elif issubclass(template, isc.ger):
|
||||||
sweep_over = [0,1,2,3,4]
|
sweep_over = [0,1,2,3,4]
|
||||||
elif issubclass(template, isc.mreduction):
|
elif issubclass(template, isc.gemv):
|
||||||
sweep_over = [0,1,2,3,4]
|
sweep_over = [0,1,2,3,4]
|
||||||
elif issubclass(template, isc.mproduct):
|
elif issubclass(template, isc.gemm):
|
||||||
sweep_over = [1,2,3,4,5,7,10,11]
|
sweep_over = [1,2,3,4,5,7,10,11]
|
||||||
|
|
||||||
#Evaluate the provided parameters guess
|
#Evaluate the provided parameters guess
|
||||||
|
@@ -36,30 +36,30 @@ def benchmark(template, setting, tree):
|
|||||||
|
|
||||||
|
|
||||||
def tree_of(template, sizes, context):
|
def tree_of(template, sizes, context):
|
||||||
if issubclass(template, isc.vaxpy):
|
if issubclass(template, isc.templates.axpy):
|
||||||
N, = sizes
|
N, = sizes
|
||||||
x = isc.empty(N, dtype=isc.float32, context=context)
|
x = isc.empty(N, dtype=isc.float32, context=context)
|
||||||
y = isc.empty(N, dtype=isc.float32, context=context)
|
y = isc.empty(N, dtype=isc.float32, context=context)
|
||||||
return x + y, (x, y)
|
return x + y, (x, y)
|
||||||
elif issubclass(template, isc.reduction):
|
elif issubclass(template, isc.templates.dot):
|
||||||
N, = sizes
|
N, = sizes
|
||||||
x = isc.empty(N, context=context)
|
x = isc.empty(N, context=context)
|
||||||
y = isc.empty(N, context=context)
|
y = isc.empty(N, context=context)
|
||||||
return isc.dot(x, y), (x, y)
|
return isc.dot(x, y), (x, y)
|
||||||
elif issubclass(template, isc.maxpy):
|
elif issubclass(template, isc.templates.ger):
|
||||||
M, N = sizes
|
M, N = sizes
|
||||||
A = isc.empty((M,N), context=context)
|
A = isc.empty((M,N), context=context)
|
||||||
B = isc.empty((M,N), context=context)
|
B = isc.empty((M,N), context=context)
|
||||||
return A + B, (A, B)
|
return A + B, (A, B)
|
||||||
elif issubclass(template, isc.mreduction):
|
elif issubclass(template, isc.templates.gemv):
|
||||||
T = template is isc.mreduction_cols
|
T = template is isc.templates.gemv_t
|
||||||
M, N = sizes[::-1] if T else sizes
|
M, N = sizes[::-1] if T else sizes
|
||||||
A = isc.empty((M,N), context=context)
|
A = isc.empty((M,N), context=context)
|
||||||
x = isc.empty(N, context=context)
|
x = isc.empty(N, context=context)
|
||||||
return isc.dot(A.T, x) if T else isc.dot(A, x), (A, x)
|
return isc.dot(A.T, x) if T else isc.dot(A, x), (A, x)
|
||||||
elif issubclass(template, isc.mproduct):
|
elif issubclass(template, isc.templates.gemm):
|
||||||
AT = template is isc.mproduct_tn or template is isc.mproduct_tt
|
AT = template is isc.templates.gemm_tn or template is isc.templates.gemm_tt
|
||||||
BT = template is isc.mproduct_nt or template is isc.mproduct_tt
|
BT = template is isc.templates.gemm_nt or template is isc.templates.gemm_tt
|
||||||
M, N, K = sizes
|
M, N, K = sizes
|
||||||
A = isc.empty((K, M) if AT else (M, K), context=context)
|
A = isc.empty((K, M) if AT else (M, K), context=context)
|
||||||
B = isc.empty((N, K) if BT else (K, N), context=context)
|
B = isc.empty((N, K) if BT else (K, N), context=context)
|
||||||
@@ -68,35 +68,35 @@ def tree_of(template, sizes, context):
|
|||||||
return isc.dot(AA, BB), (A, B)
|
return isc.dot(AA, BB), (A, B)
|
||||||
|
|
||||||
def memory_footprint(template, sizes):
|
def memory_footprint(template, sizes):
|
||||||
if issubclass(template, isc.vaxpy):
|
if issubclass(template, isc.templates.axpy):
|
||||||
return 4*3*sizes[0]*1e-9
|
return 4*3*sizes[0]*1e-9
|
||||||
elif issubclass(template, isc.reduction):
|
elif issubclass(template, isc.templates.dot):
|
||||||
return 4*2*sizes[0]*1e-9
|
return 4*2*sizes[0]*1e-9
|
||||||
elif issubclass(template, isc.maxpy):
|
elif issubclass(template, isc.templates.ger):
|
||||||
return 4*3*sizes[0]*sizes[1]*1e-9
|
return 4*3*sizes[0]*sizes[1]*1e-9
|
||||||
elif issubclass(template, isc.mreduction):
|
elif issubclass(template, isc.templates.gemv):
|
||||||
return 4*sizes[0]*sizes[1]*1e-9
|
return 4*sizes[0]*sizes[1]*1e-9
|
||||||
elif issubclass(template, isc.mproduct):
|
elif issubclass(template, isc.templates.gemm):
|
||||||
return 4*(sizes[0]*sizes[1] + sizes[0]*sizes[2] + sizes[1]*sizes[2])*1e-9
|
return 4*(sizes[0]*sizes[1] + sizes[0]*sizes[2] + sizes[1]*sizes[2])*1e-9
|
||||||
|
|
||||||
def metric_of(template):
|
def metric_of(template):
|
||||||
memory_bound = [isc.vaxpy, isc.reduction, isc.maxpy, isc.mreduction]
|
memory_bound = [isc.templates.axpy, isc.templates.dot, isc.templates.ger, isc.templates.gemv]
|
||||||
compute_bound = [isc.mproduct]
|
compute_bound = [isc.templates.gemm]
|
||||||
if any([issubclass(template, x) for x in memory_bound]):
|
if any([issubclass(template, x) for x in memory_bound]):
|
||||||
return lambda sizes, t: memory_footprint(template, sizes)/t
|
return lambda sizes, t: memory_footprint(template, sizes)/t
|
||||||
elif any([issubclass(template, x) for x in compute_bound]):
|
elif any([issubclass(template, x) for x in compute_bound]):
|
||||||
return lambda sizes, t: 2*sizes[0]*sizes[1]*sizes[2]*1e-9/t
|
return lambda sizes, t: 2*sizes[0]*sizes[1]*sizes[2]*1e-9/t
|
||||||
|
|
||||||
def genetic_infos_of(template):
|
def genetic_infos_of(template):
|
||||||
if issubclass(template, isc.vaxpy):
|
if issubclass(template, isc.templates.axpy):
|
||||||
return {'categorical': [3], 'nbits': [3,4,4,2] }
|
return {'categorical': [3], 'nbits': [3,4,4,2] }
|
||||||
elif issubclass(template, isc.reduction):
|
elif issubclass(template, isc.templates.dot):
|
||||||
return {'categorical': [3], 'nbits':[3,4,4,2]}
|
return {'categorical': [3], 'nbits':[3,4,4,2]}
|
||||||
elif issubclass(template, isc.maxpy):
|
elif issubclass(template, isc.templates.ger):
|
||||||
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
|
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
|
||||||
elif issubclass(template, isc.mreduction):
|
elif issubclass(template, isc.templates.gemv):
|
||||||
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
|
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
|
||||||
elif issubclass(template, isc.mproduct):
|
elif issubclass(template, isc.templates.gemm):
|
||||||
return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]}
|
return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]}
|
||||||
|
|
||||||
|
|
||||||
|
22
tune/tune.py
22
tune/tune.py
@@ -23,14 +23,14 @@ def tune(device, operation, json_path):
|
|||||||
|
|
||||||
#List of size tuples to use
|
#List of size tuples to use
|
||||||
sizes = {}
|
sizes = {}
|
||||||
sizes[isc.vaxpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
|
sizes[isc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
|
||||||
sizes[isc.mreduction_rows] = product(pow2range(4,17), pow2range(4,17))
|
sizes[isc.templates.gemv_n] = product(pow2range(4,17), pow2range(4,17))
|
||||||
sizes[isc.mreduction_cols] = isc.mreduction_rows
|
sizes[isc.templates.gemv_t] = sizes[isc.templates.gemv_n]
|
||||||
sizes[isc.mproduct_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
|
sizes[isc.templates.gemm_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
|
||||||
sizes[isc.mproduct_nn] = [(128, 169, 1728)]
|
sizes[isc.templates.gemm_nn] = [(128, 169, 1728)]
|
||||||
sizes[isc.mproduct_tn] = sizes[isc.mproduct_nn]
|
sizes[isc.templates.gemm_tn] = sizes[isc.templates.gemm_nn]
|
||||||
sizes[isc.mproduct_nt] = sizes[isc.mproduct_nn]
|
sizes[isc.templates.gemm_nt] = sizes[isc.templates.gemm_nn]
|
||||||
sizes[isc.mproduct_tt] = sizes[isc.mproduct_nn]
|
sizes[isc.templates.gemm_tt] = sizes[isc.templates.gemm_nn]
|
||||||
sizes = unique(list(sizes[operation]))
|
sizes = unique(list(sizes[operation]))
|
||||||
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
|
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
|
||||||
|
|
||||||
@@ -123,9 +123,9 @@ def parse_arguments():
|
|||||||
print("----------------")
|
print("----------------")
|
||||||
|
|
||||||
|
|
||||||
operation = {'vaxpy': isc.vaxpy, 'dot': isc.reduction,
|
operation = {'axpy': isc.templates.axpy, 'dot': isc.templates.dot,
|
||||||
'maxpy': isc.maxpy, 'gemv_n': isc.mreduction_rows, 'gemv_t': isc.mreduction_cols,
|
'ger': isc.templates.ger, 'gemv_n': isc.templates.gemv_n, 'gemv_t': isc.templates.gemv_t,
|
||||||
'gemm_nn': isc.mproduct_nn, 'gemm_tn': isc.mproduct_tn, 'gemm_nt': isc.mproduct_nt, 'gemm_tt':isc.mproduct_tt}[args.operation]
|
'gemm_nn': isc.templates.gemm_nn, 'gemm_tn': isc.templates.gemm_tn, 'gemm_nt': isc.templates.gemm_nt, 'gemm_tt':isc.templates.gemm_tt}[args.operation]
|
||||||
if not args.json:
|
if not args.json:
|
||||||
json = tools.sanitize(device.name) + '.json'
|
json = tools.sanitize(device.name) + '.json'
|
||||||
return (device, operation, json)
|
return (device, operation, json)
|
||||||
|
Reference in New Issue
Block a user