Cleaning: Largely renamed templates to BLAS-like names

This commit is contained in:
Philippe Tillet
2015-07-11 09:36:01 -04:00
parent 281fa9c7a6
commit cfa6ea812d
40 changed files with 606 additions and 572 deletions

View File

@@ -219,15 +219,15 @@ array_expression norm(array_expression const &, unsigned int order = 2);
array_expression repmat(array const &, int_t const & rep1, int_t const & rep2);
#define ISAAC_DECLARE_REDUCTION(OPNAME) \
#define ISAAC_DECLARE_DOT(OPNAME) \
array_expression OPNAME(array const & M, int_t axis = -1);\
array_expression OPNAME(array_expression const & M, int_t axis = -1);
ISAAC_DECLARE_REDUCTION(sum)
ISAAC_DECLARE_REDUCTION(argmax)
ISAAC_DECLARE_REDUCTION(max)
ISAAC_DECLARE_REDUCTION(min)
ISAAC_DECLARE_REDUCTION(argmin)
ISAAC_DECLARE_DOT(sum)
ISAAC_DECLARE_DOT(argmax)
ISAAC_DECLARE_DOT(max)
ISAAC_DECLARE_DOT(min)
ISAAC_DECLARE_DOT(argmin)
array_expression eye(std::size_t, std::size_t, isaac::numeric_type, driver::Context context = driver::queues.default_context());
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, driver::Context context = driver::queues.default_context());

View File

@@ -84,46 +84,46 @@ protected:
*
* Maps prod(matrix_expression, matrix_expression)
*/
class mapped_mproduct : public mapped_object, public binary_leaf
class mapped_gemm : public mapped_object, public binary_leaf
{
public:
mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info);
mapped_gemm(std::string const & scalartype, unsigned int id, node_info info);
};
/** @brief Reduction
*
* Base class for mapping a reduction
* Base class for mapping a dot
*/
class mapped_reduction : public mapped_object, public binary_leaf
class mapped_dot : public mapped_object, public binary_leaf
{
public:
mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key);
int_t root_idx() const;
isaac::array_expression const & array_expression() const;
array_expression::node root_node() const;
bool is_index_reduction() const;
bool is_index_dot() const;
op_element root_op() const;
};
/** @brief Scalar reduction
/** @brief Scalar dot
*
* Maps a scalar reduction (max, min, argmax, inner_prod, etc..)
* Maps a scalar dot (max, min, argmax, inner_prod, etc..)
*/
class mapped_scalar_reduction : public mapped_reduction
class mapped_scalar_dot : public mapped_dot
{
public:
mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info);
mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info);
};
/** @brief Vector reduction
/** @brief Vector dot
*
* Maps a row-wise reduction (max, min, argmax, matrix-vector product, etc..)
* Maps a row-wise dot (max, min, argmax, matrix-vector product, etc..)
*/
class mapped_mreduction : public mapped_reduction
class mapped_gemv : public mapped_dot
{
public:
mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info);
mapped_gemv(std::string const & scalartype, unsigned int id, node_info info);
};
/** @brief Host scalar

View File

@@ -13,8 +13,8 @@ namespace detail
{
bool is_node_leaf(op_element const & op);
bool is_scalar_reduction(array_expression::node const & node);
bool is_vector_reduction(array_expression::node const & node);
bool is_scalar_dot(array_expression::node const & node);
bool is_vector_dot(array_expression::node const & node);
bool is_assignment(op_element const & op);
bool is_elementwise_operator(op_element const & op);
bool is_elementwise_function(op_element const & op);

View File

@@ -5,27 +5,30 @@
namespace isaac
{
namespace templates
{
class vaxpy_parameters : public base::parameters_type
class axpy_parameters : public base::parameters_type
{
public:
vaxpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy);
axpy_parameters(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy);
unsigned int num_groups;
fetching_policy_type fetching_policy;
};
class vaxpy : public base_impl<vaxpy, vaxpy_parameters>
class axpy : public base_impl<axpy, axpy_parameters>
{
private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public:
vaxpy(vaxpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
vaxpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
axpy(axpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
axpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
};
}
}
#endif

View File

@@ -14,6 +14,9 @@
namespace isaac
{
namespace templates
{
enum fetching_policy_type
{
FETCH_FROM_LOCAL,
@@ -147,8 +150,8 @@ protected:
}
}
static void compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op);
static void compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op);
static void compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op);
static void compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op);
static void process_all(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings);
static void process_all_at(std::string const & type_key, std::string const & str, kernel_generation_stream & stream, std::vector<mapping_type> const & mappings, size_t root_idx, leaf_t leaf);
static std::string neutral_element(op_element const & op, driver::backend_type backend, std::string const & datatype);
@@ -159,8 +162,8 @@ protected:
static bool is_strided(array_expression::node const & node);
static int_t vector_size(array_expression::node const & node);
static std::pair<int_t, int_t> matrix_size(array_expression::node const & node);
static bool is_reduction(array_expression::node const & node);
static bool is_index_reduction(op_element const & op);
static bool is_dot(array_expression::node const & node);
static bool is_index_dot(op_element const & op);
static std::string access_vector_type(std::string const & v, int i);
tools::shared_ptr<symbolic_binder> make_binder();
@@ -204,6 +207,7 @@ protected:
binding_policy_t binding_policy_;
};
}
}
#endif

View File

@@ -1,32 +1,34 @@
#ifndef ISAAC_BACKEND_TEMPLATES_REDUCTION_H
#define ISAAC_BACKEND_TEMPLATES_REDUCTION_H
#ifndef ISAAC_BACKEND_TEMPLATES_DOT_H
#define ISAAC_BACKEND_TEMPLATES_DOT_H
#include "isaac/backend/templates/base.h"
namespace isaac
{
struct reduction_parameters : public base::parameters_type
namespace templates
{
reduction_parameters(unsigned int _simd_width,
struct dot_parameters : public base::parameters_type
{
dot_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy);
unsigned int num_groups;
fetching_policy_type fetching_policy;
};
class reduction : public base_impl<reduction, reduction_parameters>
class dot : public base_impl<dot, dot_parameters>
{
private:
unsigned int lmem_usage(expressions_tuple const & expressions) const;
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public:
reduction(reduction::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
reduction(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
dot(dot::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
dot(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private:
@@ -34,6 +36,7 @@ private:
std::vector< driver::Buffer > tmpidx_;
};
}
}
#endif

View File

@@ -7,12 +7,13 @@
namespace isaac
{
namespace templates
{
class model;
struct mproduct_parameters : public base::parameters_type
struct gemm_parameters : public base::parameters_type
{
mproduct_parameters(unsigned int simd_width
gemm_parameters(unsigned int simd_width
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
@@ -38,7 +39,7 @@ struct mproduct_parameters : public base::parameters_type
bool unroll_outer;
};
class mproduct : public base_impl<mproduct, mproduct_parameters>
class gemm : public base_impl<gemm, gemm_parameters>
{
private:
unsigned int lmem_usage(expressions_tuple const & expressions) const;
@@ -50,7 +51,7 @@ private:
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
public:
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
gemm(gemm::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
@@ -62,41 +63,41 @@ private:
bool check_bounds_;
};
class mproduct_nn : public mproduct
class gemm_nn : public gemm
{
public:
mproduct_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
gemm_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
};
class mproduct_tn : public mproduct
class gemm_tn : public gemm
{
public:
mproduct_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
gemm_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
};
class mproduct_nt : public mproduct
class gemm_nt : public gemm
{
public:
mproduct_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
gemm_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
};
class mproduct_tt : public mproduct
class gemm_tt : public gemm
{
public:
mproduct_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
gemm_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound = false);
};
}
}
#endif

View File

@@ -0,0 +1,61 @@
#ifndef ISAAC_BACKEND_TEMPLATES_MDOT_H
#define ISAAC_BACKEND_TEMPLATES_MDOT_H
#include <vector>
#include "isaac/symbolic/expression.h"
#include "isaac/backend/templates/base.h"
namespace isaac
{
namespace templates
{
struct gemv_parameters : public base::parameters_type
{
gemv_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
unsigned int num_groups_0;
unsigned int num_groups_1;
fetching_policy_type fetch_policy;
};
class gemv : public base_impl<gemv, gemv_parameters>
{
protected:
enum dot_type
{
REDUCE_ROWS,
REDUCE_COLUMNS
};
gemv(gemv::parameters_type const & , dot_type, binding_policy_t);
private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
unsigned int lmem_usage() const;
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
public:
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private:
dot_type dot_type_;
};
class gemv_n : public gemv
{
public:
gemv_n(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
class gemv_t : public gemv
{
public:
gemv_t(gemv::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
}
}
#endif

View File

@@ -6,29 +6,32 @@
namespace isaac
{
namespace templates
{
class maxpy_parameters : public base::parameters_type
class ger_parameters : public base::parameters_type
{
public:
maxpy_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy);
ger_parameters(unsigned int _simd_width, unsigned int _local_size_0, unsigned int _local_size_1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetching_policy);
unsigned int num_groups_0;
unsigned int num_groups_1;
fetching_policy_type fetching_policy;
};
class maxpy : public base_impl<maxpy, maxpy_parameters>
class ger : public base_impl<ger, ger_parameters>
{
private:
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const;
public:
maxpy(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
ger(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
ger(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
};
}
}
#endif

View File

@@ -1,59 +0,0 @@
#ifndef ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
#define ISAAC_BACKEND_TEMPLATES_MREDUCTION_H
#include <vector>
#include "isaac/symbolic/expression.h"
#include "isaac/backend/templates/base.h"
namespace isaac
{
struct mreduction_parameters : public base::parameters_type
{
mreduction_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy);
unsigned int num_groups_0;
unsigned int num_groups_1;
fetching_policy_type fetch_policy;
};
class mreduction : public base_impl<mreduction, mreduction_parameters>
{
protected:
enum reduction_type
{
REDUCE_ROWS,
REDUCE_COLUMNS
};
mreduction(mreduction::parameters_type const & , reduction_type, binding_policy_t);
private:
virtual int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
unsigned int lmem_usage() const;
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
public:
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
private:
reduction_type reduction_type_;
};
class mreduction_rows : public mreduction
{
public:
mreduction_rows(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
class mreduction_cols : public mreduction
{
public:
mreduction_cols(mreduction::parameters_type const &, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
};
}
#endif

View File

@@ -14,7 +14,7 @@ namespace isaac
class model
{
typedef tools::shared_ptr<base> template_pointer;
typedef tools::shared_ptr<templates::base> template_pointer;
typedef std::vector< template_pointer > templates_container;
private:
@@ -23,8 +23,8 @@ namespace isaac
driver::Program& init(controller<expressions_tuple> const &);
public:
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<base> > const &, driver::CommandQueue const &);
model(expression_type, numeric_type, base const &, driver::CommandQueue const &);
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
model(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
void execute(controller<expressions_tuple> const &);
templates_container const & templates() const;
@@ -46,7 +46,7 @@ namespace isaac
model_map_t init_models(driver::CommandQueue const & queue);
model_map_t& models(driver::CommandQueue & queue);
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks;
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks;
extern std::map<driver::CommandQueue, model_map_t> models_;
}

View File

@@ -31,14 +31,14 @@ enum operation_node_type_family
// BLAS1-type
OPERATOR_UNARY_TYPE_FAMILY,
OPERATOR_BINARY_TYPE_FAMILY,
OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY,
OPERATOR_VECTOR_DOT_TYPE_FAMILY,
// BLAS2-type
OPERATOR_ROWS_REDUCTION_TYPE_FAMILY,
OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY,
OPERATOR_ROWS_DOT_TYPE_FAMILY,
OPERATOR_COLUMNS_DOT_TYPE_FAMILY,
// BLAS3-type
OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
OPERATOR_GEMM_TYPE_FAMILY
};
/** @brief Enumeration for identifying the possible operations */
@@ -119,10 +119,10 @@ enum operation_node_type
OPERATOR_SHIFT_TYPE,
OPERATOR_VDIAG_TYPE,
OPERATOR_MATRIX_PRODUCT_NN_TYPE,
OPERATOR_MATRIX_PRODUCT_TN_TYPE,
OPERATOR_MATRIX_PRODUCT_NT_TYPE,
OPERATOR_MATRIX_PRODUCT_TT_TYPE,
OPERATOR_GEMM_NN_TYPE,
OPERATOR_GEMM_TN_TYPE,
OPERATOR_GEMM_NT_TYPE,
OPERATOR_GEMM_TT_TYPE,
OPERATOR_PAIR_TYPE
};

View File

@@ -128,16 +128,15 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
enum expression_type
{
INVALID_EXPRESSION_TYPE,
SCALAR_AXPY_TYPE,
VECTOR_AXPY_TYPE,
MATRIX_AXPY_TYPE,
REDUCTION_TYPE,
ROW_WISE_REDUCTION_TYPE,
COL_WISE_REDUCTION_TYPE,
MATRIX_PRODUCT_NN_TYPE,
MATRIX_PRODUCT_TN_TYPE,
MATRIX_PRODUCT_NT_TYPE,
MATRIX_PRODUCT_TT_TYPE
AXPY_TYPE,
GER_TYPE,
DOT_TYPE,
GEMV_N_TYPE,
GEMV_T_TYPE,
GEMM_NN_TYPE,
GEMM_TN_TYPE,
GEMM_NT_TYPE,
GEMM_TT_TYPE
};
struct slice

View File

@@ -596,17 +596,17 @@ array_expression repmat(array_expression const & A, int_t const & rep1, int_t co
///*--- Reductions ---*/
////---------------------------------------
#define DEFINE_REDUCTION(OP, OPNAME)\
#define DEFINE_DOT(OP, OPNAME)\
array_expression OPNAME(array const & x, int_t axis)\
{\
if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\
else if(axis==-1)\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
else if(axis==0)\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
else\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
}\
\
array_expression OPNAME(array_expression const & x, int_t axis)\
@@ -614,20 +614,20 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
if(axis < -1 || axis > x.nshape())\
throw std::out_of_range("The axis entry is out of bounds");\
if(axis==-1)\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(1));\
else if(axis==0)\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[1]));\
else\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()[0]));\
}
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_REDUCTION(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_REDUCTION(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
DEFINE_DOT(OPERATOR_ADD_TYPE, sum)
DEFINE_DOT(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_DOT(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_DOT(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_DOT(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)
#undef DEFINE_REDUCTION
#undef DEFINE_DOT
namespace detail
{
@@ -635,21 +635,21 @@ namespace detail
array_expression matmatprod(array const & A, array const & B)
{
size4 shape(A.shape()[0], B.shape()[1]);
return array_expression(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, OPERATOR_MATRIX_PRODUCT_NN_TYPE), A.context(), A.dtype(), shape);
return array_expression(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, OPERATOR_GEMM_NN_TYPE), A.context(), A.dtype(), shape);
}
array_expression matmatprod(array_expression const & A, array const & B)
{
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
size4 shape(A.shape()[0], B.shape()[1]);
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans){
type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
type = OPERATOR_GEMM_TN_TYPE;
}
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
return res;
@@ -657,16 +657,16 @@ namespace detail
array_expression matmatprod(array const & A, array_expression const & B)
{
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
size4 shape(A.shape()[0], B.shape()[1]);
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
if(B_trans){
type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
type = OPERATOR_GEMM_NT_TYPE;
}
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(B_trans) res_root.rhs = B_root.lhs;
return res;
@@ -674,7 +674,7 @@ namespace detail
array_expression matmatprod(array_expression const & A, array_expression const & B)
{
operation_node_type type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
operation_node_type type = OPERATOR_GEMM_NN_TYPE;
array_expression::node & A_root = const_cast<array_expression::node &>(A.tree()[A.root()]);
array_expression::node & B_root = const_cast<array_expression::node &>(B.tree()[B.root()]);
size4 shape(A.shape()[0], B.shape()[1]);
@@ -682,12 +682,12 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
else type = OPERATOR_MATRIX_PRODUCT_NN_TYPE;
if(A_trans && B_trans) type = OPERATOR_GEMM_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_GEMM_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_GEMM_NT_TYPE;
else type = OPERATOR_GEMM_NN_TYPE;
array_expression res(A, B, op_element(OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression res(A, B, op_element(OPERATOR_GEMM_TYPE_FAMILY, type), A.context(), A.dtype(), shape);
array_expression::node & res_root = const_cast<array_expression::node &>(res.tree()[res.root()]);
if(A_trans) res_root.lhs = A_root.lhs;
if(B_trans) res_root.rhs = B_root.lhs;

View File

@@ -102,23 +102,23 @@ std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, s
}
mapped_mproduct::mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "mproduct"), binary_leaf(info) { }
mapped_gemm::mapped_gemm(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "gemm"), binary_leaf(info) { }
//
mapped_reduction::mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
mapped_dot::mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
mapped_object(scalartype, id, type_key), binary_leaf(info)
{ }
int_t mapped_reduction::root_idx() const
int_t mapped_dot::root_idx() const
{ return info_.root_idx; }
isaac::array_expression const & mapped_reduction::array_expression() const
isaac::array_expression const & mapped_dot::array_expression() const
{ return *info_.array_expression; }
array_expression::node mapped_reduction::root_node() const
array_expression::node mapped_dot::root_node() const
{ return array_expression().tree()[root_idx()]; }
bool mapped_reduction::is_index_reduction() const
bool mapped_dot::is_index_dot() const
{
op_element const & op = root_op();
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
@@ -127,17 +127,17 @@ bool mapped_reduction::is_index_reduction() const
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
}
op_element mapped_reduction::root_op() const
op_element mapped_dot::root_op() const
{
return info_.array_expression->tree()[info_.root_idx].op;
}
//
mapped_scalar_reduction::mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ }
mapped_scalar_dot::mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "scalar_dot"){ }
//
mapped_mreduction::mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "mreduction") { }
mapped_gemv::mapped_gemv(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "gemv") { }
//
void mapped_host_scalar::preprocess(std::string & str) const

View File

@@ -10,15 +10,15 @@ namespace detail
bool is_scalar_reduction(array_expression::node const & node)
bool is_scalar_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
}
bool is_vector_reduction(array_expression::node const & node)
bool is_vector_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
}
bool is_assignment(op_element const & op)
@@ -75,10 +75,10 @@ namespace detail
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
;
}
@@ -214,10 +214,10 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
//Binary
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN";
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN";
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT";
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT";
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
case OPERATOR_VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row";

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/keywords.h"
#include "isaac/driver/backend.h"
#include "isaac/tools/make_map.hpp"
@@ -8,23 +8,24 @@
namespace isaac
{
namespace templates
{
vaxpy_parameters::vaxpy_parameters(unsigned int _simd_width,
axpy_parameters::axpy_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) :
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ }
int vaxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int axpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string axpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
driver::backend_type backend = device.backend();
std::string _size_t = size_type(device);
@@ -90,25 +91,24 @@ std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str();
}
vaxpy::vaxpy(vaxpy_parameters const & parameters,
axpy::axpy(axpy_parameters const & parameters,
binding_policy_t binding_policy) :
base_impl<vaxpy, vaxpy_parameters>(parameters, binding_policy)
base_impl<axpy, axpy_parameters>(parameters, binding_policy)
{}
vaxpy::vaxpy(unsigned int simd, unsigned int ls, unsigned int ng,
axpy::axpy(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind):
base_impl<vaxpy, vaxpy_parameters>(vaxpy_parameters(simd,ls,ng,fetch), bind)
base_impl<axpy, axpy_parameters>(axpy_parameters(simd,ls,ng,fetch), bind)
{}
std::vector<int_t> vaxpy::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) const
{
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
int_t size = static_cast<array_expression const *>(expressions.data().front().get())->shape()[0];
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
}
void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
//Size
@@ -135,3 +135,4 @@ void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
}
}

View File

@@ -2,11 +2,11 @@
#include "isaac/array.h"
#include "isaac/backend/keywords.h"
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/backend/templates/base.h"
#include "isaac/backend/parse.h"
#include "isaac/exception/operation_not_supported.h"
@@ -17,6 +17,8 @@
namespace isaac
{
namespace templates
{
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
{ }
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
else if (detail::is_scalar_reduction(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_)));
else if (detail::is_vector_reduction(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_)));
else if (detail::is_scalar_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
else if (detail::is_vector_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
set_arguments(root_node.rhs);
}
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
{
if (detail::is_elementwise_function(op))
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
}
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
{
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known");
default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
}
}
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
}
bool base::is_reduction(array_expression::node const & node)
bool base::is_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
}
bool base::is_index_reduction(op_element const & op)
bool base::is_index_dot(op_element const & op)
{
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
return is_invalid_impl(device, expressions);
}
template class base_impl<vaxpy, vaxpy_parameters>;
template class base_impl<reduction, reduction_parameters>;
template class base_impl<maxpy, maxpy_parameters>;
template class base_impl<mreduction, mreduction_parameters>;
template class base_impl<mproduct, mproduct_parameters>;
template class base_impl<axpy, axpy_parameters>;
template class base_impl<dot, dot_parameters>;
template class base_impl<ger, ger_parameters>;
template class base_impl<gemv, gemv_parameters>;
template class base_impl<gemm, gemm_parameters>;
}
}

View File

@@ -1,5 +1,5 @@
#include <iostream>
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/dot.h"
#include <CL/cl.hpp>
#include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp"
@@ -7,13 +7,14 @@
#include "isaac/backend/keywords.h"
namespace isaac
{
reduction_parameters::reduction_parameters(unsigned int _simd_width,
namespace templates
{
dot_parameters::dot_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ }
unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
unsigned int dot::lmem_usage(expressions_tuple const & expressions) const
{
unsigned int res = 0;
for(const auto & elem : expressions.data())
@@ -24,14 +25,14 @@ unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
return res;
}
int reduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int dot::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
inline void dot::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
{
stream << "#pragma unroll" << std::endl;
@@ -44,26 +45,26 @@ inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream,
stream.inc_tab();
for (auto & expr : exprs)
if (expr->is_index_reduction())
compute_index_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
if (expr->is_index_dot())
compute_index_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
expr->root_op());
else
compute_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
compute_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
stream.dec_tab();
stream << "}" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
}
std::string reduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string dot::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
kernel_generation_stream stream;
std::vector<mapped_scalar_reduction*> exprs;
std::vector<mapped_scalar_dot*> exprs;
for (const auto & mapping : mappings)
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
if (mapped_scalar_reduction * p = dynamic_cast<mapped_scalar_reduction*>(iit->second.get()))
if (mapped_scalar_dot * p = dynamic_cast<mapped_scalar_dot*>(iit->second.get()))
exprs.push_back(p);
std::size_t N = exprs.size();
driver::backend_type backend = device.backend();
@@ -73,7 +74,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k)
{
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
{
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
@@ -112,7 +113,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
{
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
@@ -156,11 +157,11 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
accessors["matrix_diag"] = str[a];
accessors["array0"] = "#namereg";
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
if (elem->is_index_reduction())
compute_index_reduction(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
if (elem->is_index_dot())
compute_index_dot(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
}
}
});
@@ -168,7 +169,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
//Fills local memory
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
}
@@ -182,7 +183,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream.inc_tab();
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
}
@@ -205,9 +206,9 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
for (mapped_scalar_reduction* e: exprs)
for (mapped_scalar_dot* e: exprs)
{
if (e->is_index_reduction())
if (e->is_index_dot())
{
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
stream << e->process("unsigned int #name_acc = 0;") << std::endl;
@@ -224,18 +225,18 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
for (mapped_scalar_reduction* e: exprs)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
for (mapped_scalar_dot* e: exprs)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
else
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
}
@@ -248,7 +249,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "{" << std::endl;
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["scalar_reduction"] = "#name_buf[0]";
accessors["scalar_dot"] = "#name_buf[0]";
accessors["array0"] = "#pointer[#start]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
stream.dec_tab();
@@ -260,23 +261,23 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
return stream.str();
}
reduction::reduction(reduction::parameters_type const & parameters,
binding_policy_t binding) : base_impl<reduction, reduction_parameters>(parameters, binding)
dot::dot(dot::parameters_type const & parameters,
binding_policy_t binding) : base_impl<dot, dot_parameters>(parameters, binding)
{ }
reduction::reduction(unsigned int simd, unsigned int ls, unsigned int ng,
dot::dot(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind):
base_impl<reduction, reduction_parameters>(reduction_parameters(simd,ls,ng,fetch), bind)
base_impl<dot, dot_parameters>(dot_parameters(simd,ls,ng,fetch), bind)
{}
std::vector<int_t> reduction::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *(expressions.data().front()), false);
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), reductions_idx[0]));
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *(expressions.data().front()), false);
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), dots_idx[0]));
return tools::make_vector<int_t>() << N;
}
void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
@@ -290,12 +291,12 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
return;
}
std::vector<array_expression::node const *> reductions;
std::vector<array_expression::node const *> dots;
for (const auto & elem : expressions.data())
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *elem, false);
for (auto & reductions_idx_itt : reductions_idx)
reductions.push_back(&(elem)->tree()[reductions_idx_itt]);
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *elem, false);
for (auto & dots_idx_itt : dots_idx)
dots.push_back(&(elem)->tree()[dots_idx_itt]);
}
//Kernel
@@ -321,9 +322,9 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
//Temporary buffers
unsigned int i = 0;
unsigned int j = 0;
for (std::vector<array_expression::node const *>::const_iterator it = reductions.begin(); it != reductions.end(); ++it)
for (std::vector<array_expression::node const *>::const_iterator it = dots.begin(); it != dots.end(); ++it)
{
if (is_index_reduction((*it)->op))
if (is_index_dot((*it)->op))
{
if (tmpidx_.size() <= j)
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
@@ -343,3 +344,4 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
}
}
}

View File

@@ -1,5 +1,5 @@
#include "isaac/array.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/backend/keywords.h"
#include "isaac/model/model.h"
#include "isaac/symbolic/preset.h"
@@ -10,8 +10,10 @@
namespace isaac
{
namespace templates
{
mproduct_parameters::mproduct_parameters(unsigned int simd_width
gemm_parameters::gemm_parameters(unsigned int simd_width
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
@@ -21,7 +23,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
mL(ms*local_size_0), nL(ns*local_size_1){}
unsigned int mproduct::lmem_usage(expressions_tuple const & expressions) const
unsigned int gemm::lmem_usage(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -32,7 +34,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
unsigned int mproduct::registers_usage(expressions_tuple const & expressions) const
unsigned int gemm::registers_usage(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -41,7 +43,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
int mproduct::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
{
std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1];
@@ -95,7 +97,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return TEMPLATE_VALID;
}
std::string mproduct::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
std::string gemm::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
{
using std::string;
using tools::to_string;
@@ -437,7 +439,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
#undef VST0RE
}
void mproduct::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K,
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C,
value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options)
@@ -516,7 +518,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
}
}
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{
slice s0(s0_0, s0_1);
slice s1(s1_0, s1_1);
@@ -525,7 +527,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return array(M, s0, s1);
}
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
std::vector<int_t> gemm::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
{
isaac::array_expression & array_expression = (*expressions.data().front());
array_expression::container_type & array = array_expression.tree();
@@ -537,26 +539,26 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return {M, N, K};
}
mproduct::mproduct(mproduct_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
{
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN_TYPE;
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN_TYPE;
else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT_TYPE;
else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT_TYPE;
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
else throw;
}
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> gemm::input_sizes(expressions_tuple const & expressions) const
{
symbolic::preset::gemm::args dummy;
return infos(expressions, dummy);
}
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
{
using namespace tools;
mproduct & fallback = (mproduct&)fallback_base;
gemm & fallback = (gemm&)fallback_base;
expressions_tuple const & expressions = ctr.x();
@@ -579,8 +581,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t ldstrideA = pA->stride()[0];
int_t ldstrideB = pB->stride()[0];
int_t ldstrideC = pC->stride()[0];
int_t ldstartA = pA->start()[0];
int_t ldstartB = pB->start()[0];
numeric_type dtype = args.C->dtype;
@@ -613,40 +613,41 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
}
//
mproduct_nn::mproduct_nn(unsigned int simd
gemm_nn::gemm_nn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
{ }
//
mproduct_tn::mproduct_tn(unsigned int simd
gemm_tn::gemm_tn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
{ }
//
mproduct_nt::mproduct_nt(unsigned int simd
gemm_nt::gemm_nt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
{ }
//
mproduct_tt::mproduct_tt(unsigned int simd
gemm_tt::gemm_tt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
{ }
}
}

View File

@@ -1,48 +1,50 @@
#include <iostream>
#include "isaac/backend/stream.h"
#include "isaac/backend/keywords.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp"
namespace isaac
{
namespace templates
{
mreduction_parameters::mreduction_parameters(unsigned int _simd_width,
gemv_parameters::gemv_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
int mreduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if(reduction_type_==REDUCE_ROWS && p_.simd_width>1)
if(dot_type_==REDUCE_ROWS && p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
if (p_.fetch_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int mreduction::lmem_usage() const
unsigned int gemv::lmem_usage() const
{
return (p_.local_size_0+1)*p_.local_size_1;
}
std::string mreduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string gemv::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
using tools::to_string;
std::vector<mapped_mreduction*> reductions;
std::vector<mapped_gemv*> dots;
expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::const_iterator mit;
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
{
array_expression const & first_expression = *expressions.data().front();
std::vector<size_t> idx = filter_nodes(&is_reduction, first_expression, false);
std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
for (auto & elem : idx)
reductions.push_back((mapped_mreduction*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
}
kernel_generation_stream stream;
@@ -54,10 +56,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
strcat(name[1], suffix);
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
for (const auto & e : reductions)
for (const auto & e : dots)
{
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
if (e->is_index_reduction())
if (e->is_index_dot())
{
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
@@ -87,7 +89,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
unsigned int local_size_0_ld = p_.local_size_0;
std::string local_size_0_ld_str = to_string(local_size_0_ld);
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -104,7 +106,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab();
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl;
@@ -116,10 +118,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
std::string data_type = append_width("#scalartype",simd_width);
for (const auto & e : reductions)
for (const auto & e : dots)
{
std::map<std::string, std::string> accessors;
if(reduction_type_==REDUCE_COLUMNS)
if(dot_type_==REDUCE_COLUMNS)
{
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
@@ -141,20 +143,20 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
str[a] = access_vector_type("#namereg",a);
for (auto & elem : reductions)
for (auto & elem : dots)
for (unsigned int a = 0; a < simd_width; ++a)
{
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
if (elem->is_index_reduction())
compute_index_reduction(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
if (elem->is_index_dot())
compute_index_dot(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
}
});
stream.dec_tab();
stream << "}" << std::endl;
for (auto & expr : reductions)
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
@@ -167,13 +169,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : reductions)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op());
else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -188,15 +190,15 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
if(p_.num_groups_0==1)
{
std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
}
else
{
for (mapped_reduction const * e : reductions)
for (mapped_dot const * e : dots)
{
if (e->is_index_reduction())
if (e->is_index_dot())
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
}
@@ -230,7 +232,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
{"array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "}}, expressions, mappings);
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -246,7 +248,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab();
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl;
@@ -256,8 +258,8 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
stream.inc_tab();
for (mapped_reduction* e: reductions)
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
for (mapped_dot* e: dots)
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -266,7 +268,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.dec_tab();
stream << "}" << std::endl;
for (auto & expr : reductions)
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
@@ -279,13 +281,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : reductions)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op());
else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -299,7 +301,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
@@ -317,38 +319,38 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
return stream.str();
}
mreduction::mreduction(mreduction::parameters_type const & parameters,
mreduction::reduction_type rtype,
gemv::gemv(gemv::parameters_type const & parameters,
gemv::dot_type rtype,
binding_policy_t binding_policy) :
base_impl<mreduction, mreduction_parameters>(parameters, binding_policy),
reduction_type_(rtype){ }
base_impl<gemv, gemv_parameters>(parameters, binding_policy),
dot_type_(rtype){ }
std::vector<int_t> mreduction::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
{
array_expression const & first_expression = *expressions.data().front();
std::vector<std::size_t> idx = filter_nodes(&is_reduction, first_expression, false);
std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
if(reduction_type_==REDUCE_COLUMNS)
if(dot_type_==REDUCE_COLUMNS)
std::swap(MN.first,MN.second);
return tools::make_vector<int_t>() << MN.first << MN.second;
}
void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
driver::Context const & context = expressions.context();
std::vector<int_t> MN = input_sizes(expressions);
std::vector<array_expression::node const *> reductions;
std::vector<array_expression::node const *> dots;
for (const auto & e : expressions.data())
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *e, false);
for (auto & r : reductions_idx)
reductions.push_back(&(e)->tree()[r]);
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
for (auto & r : dots_idx)
dots.push_back(&(e)->tree()[r]);
}
//Fallback
if(reduction_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
if(dot_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
{
fallback.enqueue(queue, program, "fallback", fallback, controller);
return;
@@ -381,9 +383,9 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
//Temporary buffers
unsigned int i = 0;
unsigned int j = 0;
for (auto const & r : reductions)
for (auto const & r : dots)
{
if (is_index_reduction(r->op))
if (is_index_dot(r->op))
{
if (tmpidx.size() <= j)
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
@@ -405,24 +407,25 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
}
mreduction_rows::mreduction_rows(mreduction_parameters const & parameters,
gemv_n::gemv_n(gemv_parameters const & parameters,
binding_policy_t binding_policy):
mreduction(parameters, REDUCE_ROWS, binding_policy){}
gemv(parameters, REDUCE_ROWS, binding_policy){}
mreduction_rows::mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2,
gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
{}
mreduction_cols::mreduction_cols(mreduction::parameters_type const & parameters,
gemv_t::gemv_t(gemv::parameters_type const & parameters,
binding_policy_t binding_policy):
mreduction(parameters, REDUCE_COLUMNS, binding_policy){}
gemv(parameters, REDUCE_COLUMNS, binding_policy){}
mreduction_cols::mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2,
gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
{}
}
}

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp"
#include "isaac/symbolic/io.h"
@@ -7,15 +7,17 @@
namespace isaac
{
namespace templates
{
maxpy_parameters::maxpy_parameters(unsigned int _simd_width,
ger_parameters::ger_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int ger::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
@@ -24,7 +26,7 @@ int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) co
return TEMPLATE_VALID;
}
std::string maxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string ger::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
kernel_generation_stream stream;
std::string _size_t = size_type(device);
@@ -95,23 +97,23 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str();
}
maxpy::maxpy(parameters_type const & parameters, binding_policy_t binding_policy) :
base_impl<maxpy, maxpy_parameters>(parameters, binding_policy){ }
ger::ger(parameters_type const & parameters, binding_policy_t binding_policy) :
base_impl<ger, ger_parameters>(parameters, binding_policy){ }
maxpy::maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2,
ger::ger(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
binding_policy_t bind):
base_impl<maxpy, maxpy_parameters>(maxpy_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
base_impl<ger, ger_parameters>(ger_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
{}
std::vector<int_t> maxpy::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = *(expressions.data().front());
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
return tools::make_vector<int_t>() << size.first << size.second;
}
void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
char name[32] = {"axpy"};
@@ -129,3 +131,4 @@ void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
}
}
}

View File

@@ -7,11 +7,11 @@
#include "rapidjson/document.h"
#include "isaac/backend/parse.h"
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/exception/unknown_datatype.h"
#include "isaac/exception/operation_not_supported.h"
#include "isaac/model/model.h"
@@ -86,12 +86,12 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
return *program;
}
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<base> > const & templates, driver::CommandQueue const & queue) :
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
{}
model::model(expression_type etype, numeric_type dtype, base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue)
model::model(expression_type etype, numeric_type dtype, templates::base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), fallback_(fallbacks[std::make_pair(etype, dtype)]), queue_(queue)
{}
void model::execute(controller<expressions_tuple> const & expr)
@@ -148,15 +148,15 @@ namespace detail
{
static expression_type get_expression_type(std::string const & name)
{
if(name=="vaxpy") return VECTOR_AXPY_TYPE;
if(name=="dot") return REDUCTION_TYPE;
if(name=="maxpy") return MATRIX_AXPY_TYPE;
if(name=="mreduction_rows") return ROW_WISE_REDUCTION_TYPE;
if(name=="mreduction_cols") return COL_WISE_REDUCTION_TYPE;
if(name=="mproduct_nn") return MATRIX_PRODUCT_NN_TYPE;
if(name=="mproduct_nt") return MATRIX_PRODUCT_NT_TYPE;
if(name=="mproduct_tn") return MATRIX_PRODUCT_TN_TYPE;
if(name=="mproduct_tt") return MATRIX_PRODUCT_TT_TYPE;
if(name=="axpy") return AXPY_TYPE;
if(name=="dot") return DOT_TYPE;
if(name=="ger") return GER_TYPE;
if(name=="gemv_n") return GEMV_N_TYPE;
if(name=="gemv_t") return GEMV_T_TYPE;
if(name=="gemm_nn") return GEMM_NN_TYPE;
if(name=="gemm_nt") return GEMM_NT_TYPE;
if(name=="gemm_tn") return GEMM_TN_TYPE;
if(name=="gemm_tt") return GEMM_TT_TYPE;
throw std::invalid_argument("Invalid expression: " + name);
}
@@ -167,27 +167,27 @@ namespace detail
throw std::invalid_argument("Invalid datatype: " + name);
}
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
{
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="vaxpy")
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="axpy")
return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="dot")
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="maxpy")
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mreduction_rows")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mreduction_cols")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("mproduct_nn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_tn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_nt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("mproduct_tt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="ger")
return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_n")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_t")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemm_nn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_nt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
}
@@ -207,7 +207,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
document.Parse<0>(str.c_str());
//Deserialize
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemv_n", "gemv_t", "mproduct_nn", "mproduct_tn", "mproduct_nt", "mproduct_tt"};
std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
std::vector<std::string> dtype = {"float32", "float64"};
for(auto & operation : operations)
{
@@ -223,7 +223,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
numeric_type dtype = detail::get_dtype(elem);
// Get profiles
std::vector<tools::shared_ptr<base> > templates;
std::vector<tools::shared_ptr<templates::base> > templates;
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
@@ -243,23 +243,22 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > init_fallback()
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
{
typedef tools::shared_ptr<base> ptr_t;
typedef tools::shared_ptr<templates::base> ptr_t;
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types)
{
res[std::make_pair(SCALAR_AXPY_TYPE, DTYPE)] = ptr_t(new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 32, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(AXPY_TYPE, DTYPE)] = ptr_t (new templates::axpy(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(DOT_TYPE, DTYPE)] = ptr_t(new templates::dot(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::ger(1,8,8,8,8,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::gemv_n(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::gemv_t(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::gemm_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
}
return res;
}
@@ -269,7 +268,7 @@ model_map_t init_models(driver::CommandQueue & queue)
{
model_map_t res;
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
expression_type etypes[] = {SCALAR_AXPY_TYPE, VECTOR_AXPY_TYPE, REDUCTION_TYPE, MATRIX_AXPY_TYPE, ROW_WISE_REDUCTION_TYPE, COL_WISE_REDUCTION_TYPE, MATRIX_PRODUCT_NN_TYPE, MATRIX_PRODUCT_NT_TYPE, MATRIX_PRODUCT_TN_TYPE, MATRIX_PRODUCT_TT_TYPE};
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
for(numeric_type dtype: dtypes)
for(expression_type etype: etypes)
@@ -288,7 +287,7 @@ model_map_t& models(driver::CommandQueue & queue)
return it->second;
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > fallbacks = init_fallback();
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<driver::CommandQueue, model_map_t> models_;
}

View File

@@ -19,13 +19,13 @@ namespace isaac
inline bool is_mmprod(expression_type x)
{
return x==MATRIX_PRODUCT_NN_TYPE || x==MATRIX_PRODUCT_NT_TYPE ||
x==MATRIX_PRODUCT_TN_TYPE || x==MATRIX_PRODUCT_TT_TYPE;
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
}
inline bool is_mvprod(expression_type x)
{
return x==ROW_WISE_REDUCTION_TYPE || x==COL_WISE_REDUCTION_TYPE;
return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
}
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
@@ -36,27 +36,27 @@ namespace isaac
case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY:
result |= is_mmprod(expression)
|| (result |= expression==ROW_WISE_REDUCTION_TYPE && other==COL_WISE_REDUCTION_TYPE)
|| (result |= expression==COL_WISE_REDUCTION_TYPE && other==ROW_WISE_REDUCTION_TYPE);
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
break;
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression)
|| expression==REDUCTION_TYPE;
|| expression==DOT_TYPE;
break;
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==REDUCTION_TYPE;
|| expression==DOT_TYPE;
break;
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==REDUCTION_TYPE;
|| expression==DOT_TYPE;
break;
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
case OPERATOR_GEMM_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression)
|| expression==REDUCTION_TYPE;
|| expression==DOT_TYPE;
break;
default:
break;
@@ -77,28 +77,28 @@ namespace isaac
{
case OPERATOR_UNARY_TYPE_FAMILY:
if(is_mmprod(left))
return MATRIX_AXPY_TYPE;
return GER_TYPE;
return left;
case OPERATOR_BINARY_TYPE_FAMILY:
if(left == ROW_WISE_REDUCTION_TYPE || right == ROW_WISE_REDUCTION_TYPE) return ROW_WISE_REDUCTION_TYPE;
else if(left == COL_WISE_REDUCTION_TYPE || right == COL_WISE_REDUCTION_TYPE) return COL_WISE_REDUCTION_TYPE;
else if(left == REDUCTION_TYPE || right == REDUCTION_TYPE) return REDUCTION_TYPE;
else if(left == VECTOR_AXPY_TYPE || right == VECTOR_AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?MATRIX_AXPY_TYPE:VECTOR_AXPY_TYPE;
else if(left == MATRIX_AXPY_TYPE || right == MATRIX_AXPY_TYPE) return MATRIX_AXPY_TYPE;
else if(is_mmprod(left) || is_mmprod(right)) return MATRIX_AXPY_TYPE;
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
std::cout << left << " " << right << std::endl;
throw;
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
return REDUCTION_TYPE;
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
return ROW_WISE_REDUCTION_TYPE;
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
return COL_WISE_REDUCTION_TYPE;
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
if(op.type==OPERATOR_MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN_TYPE;
else if(op.type==OPERATOR_MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN_TYPE;
else if(op.type==OPERATOR_MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT_TYPE;
else return MATRIX_PRODUCT_TT_TYPE;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
return DOT_TYPE;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
return GEMV_N_TYPE;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
return GEMV_T_TYPE;
case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
else return GEMM_TT_TYPE;
default:
throw;
}
@@ -119,9 +119,9 @@ namespace isaac
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.lhs.array->nshape()==1)
type_left = VECTOR_AXPY_TYPE;
type_left = AXPY_TYPE;
else
type_left = MATRIX_AXPY_TYPE;
type_left = GER_TYPE;
}
//Right
@@ -131,9 +131,9 @@ namespace isaac
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.rhs.array->nshape()==1)
type_right = VECTOR_AXPY_TYPE;
type_right = AXPY_TYPE;
else
type_right = MATRIX_AXPY_TYPE;
type_right = GER_TYPE;
}
@@ -171,12 +171,10 @@ namespace isaac
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
if(root_save.lhs.array->nshape()<=1)
current_type=AXPY_TYPE;
else
current_type=MATRIX_AXPY_TYPE;
current_type=GER_TYPE;
final_type = current_type;
/*----Parse required temporaries-----*/
@@ -193,18 +191,17 @@ namespace isaac
//Creates temporary
tools::shared_ptr<array> tmp;
switch(it->first){
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation");
}

View File

@@ -12,16 +12,16 @@ namespace preset
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
{
//Matrix-Matrix product node
if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
if(tree[rootidx].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
{
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
switch(tree[rootidx].op.type)
{
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break;
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break;
case OPERATOR_GEMM_NN_TYPE: a.type = GEMM_NN_TYPE; break;
case OPERATOR_GEMM_NT_TYPE: a.type = GEMM_NT_TYPE; break;
case OPERATOR_GEMM_TN_TYPE: a.type = GEMM_TN_TYPE; break;
case OPERATOR_GEMM_TT_TYPE: a.type = GEMM_TT_TYPE; break;
default: break;
}
}
@@ -31,7 +31,7 @@ void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, a
{
//alpha*PROD
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_GEMM_TYPE_FAMILY)
{
a.alpha = &tree[rootidx].lhs;
handle_node(tree, tree[rootidx].rhs.node_index, a);

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -20,6 +20,6 @@ BOOST_PYTHON_MODULE(_isaac)
export_driver();
export_exceptions();
export_model();
export_templates();
export_core();
}

View File

@@ -70,15 +70,15 @@ namespace tools
else
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
if(name=="vaxpy") return isc::VECTOR_AXPY_TYPE;
else if(name=="maxpy") return isc::MATRIX_AXPY_TYPE;
else if(name=="reduction") return isc::REDUCTION_TYPE;
else if(name=="mreduction_rows") return isc::ROW_WISE_REDUCTION_TYPE;
else if(name=="mreduction_cols") return isc::COL_WISE_REDUCTION_TYPE;
else if(name=="mproduct_nn") return isc::MATRIX_PRODUCT_NN_TYPE;
else if(name=="mproduct_tn") return isc::MATRIX_PRODUCT_TN_TYPE;
else if(name=="mproduct_nt") return isc::MATRIX_PRODUCT_NT_TYPE;
else if(name=="mproduct_tt") return isc::MATRIX_PRODUCT_TT_TYPE;
if(name=="axpy") return isc::AXPY_TYPE;
else if(name=="ger") return isc::GER_TYPE;
else if(name=="dot") return isc::DOT_TYPE;
else if(name=="gemv_n") return isc::GEMV_N_TYPE;
else if(name=="gemm_t") return isc::GEMV_T_TYPE;
else if(name=="gemm_nn") return isc::GEMM_NN_TYPE;
else if(name=="gemm_tn") return isc::GEMM_TN_TYPE;
else if(name=="gemm_nt") return isc::GEMM_NT_TYPE;
else if(name=="gemm_tt") return isc::GEMM_TT_TYPE;
else
{
PyErr_SetString(PyExc_TypeError, "Template type not understood");

View File

@@ -1,3 +1,4 @@
#include "isaac/model/model.h"
#include "common.hpp"
#include "core.h"
@@ -79,6 +80,11 @@ unsigned int size(datatype<T> const & dt)
namespace detail
{
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
{
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isaac::templates::base const &)bp::extract<isaac::templates::base>(tp), queue));
}
std::shared_ptr<isc::array>
ndarray_to_iscarray(const np::ndarray& array, const isc::driver::Context& ctx)
{
@@ -172,7 +178,6 @@ namespace detail
bp::throw_error_already_set();
throw;
}
}
}
@@ -183,6 +188,10 @@ namespace detail
void export_core()
{
bp::class_<isaac::model>("model", bp::no_init)
.def("__init__", bp::make_constructor(detail::construct_model))
.def("execute", &isc::model::execute);
bp::class_<isc::value_scalar>("value_scalar", bp::no_init)
.add_property("dtype", &isc::value_scalar::dtype);
@@ -206,15 +215,15 @@ void export_core()
#undef INSTANTIATE
bp::enum_<isc::expression_type>("operations")
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
MAP_ENUM(MATRIX_AXPY_TYPE, isc)
MAP_ENUM(REDUCTION_TYPE, isc)
MAP_ENUM(ROW_WISE_REDUCTION_TYPE, isc)
MAP_ENUM(COL_WISE_REDUCTION_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc)
MAP_ENUM(VECTOR_AXPY_TYPE, isc);
MAP_ENUM(AXPY_TYPE, isc)
MAP_ENUM(GER_TYPE, isc)
MAP_ENUM(DOT_TYPE, isc)
MAP_ENUM(GEMV_N_TYPE, isc)
MAP_ENUM(GEMV_T_TYPE, isc)
MAP_ENUM(GEMM_NN_TYPE, isc)
MAP_ENUM(GEMM_TN_TYPE, isc)
MAP_ENUM(GEMM_NT_TYPE, isc)
MAP_ENUM(GEMM_TT_TYPE, isc);
#define ADD_SCALAR_HANDLING(OP)\
.def(bp::self OP int())\

View File

@@ -1,70 +1,71 @@
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/model/model.h"
#include "common.hpp"
#include "model.h"
namespace tpt = isaac::templates;
namespace detail
{
bp::list input_sizes(isaac::base & temp, isc::expressions_tuple const & tree)
bp::list input_sizes(tpt::base & temp, isc::expressions_tuple const & tree)
{
std::vector<int> tmp = temp.input_sizes(tree);
return tools::to_list(tmp.begin(), tmp.end());
}
std::shared_ptr<isc::model> construct_model(bp::object dtype, bp::object const & tp, isc::driver::CommandQueue & queue)
{
return std::shared_ptr<isc::model>(new isc::model(tools::extract_template_type(tp), tools::extract_dtype(dtype), (isc::base const &)bp::extract<isc::base>(tp), queue));
}
}
void export_model()
void export_templates()
{
bp::class_<isaac::model>("model", bp::no_init)
.def("__init__", bp::make_constructor(detail::construct_model))
.def("execute", &isc::model::execute);
bp::object templates_module(bp::handle<>(bp::borrowed(PyImport_AddModule("isaac.templates"))));
bp::scope().attr("templates") = templates_module;
bp::scope template_scope = templates_module;
bp::enum_<isaac::fetching_policy_type>
bp::enum_<tpt::fetching_policy_type>
("fetching_policy_type")
.value("FETCH_FROM_LOCAL", isc::FETCH_FROM_LOCAL)
.value("FETCH_FROM_GLOBAL_STRIDED", isc::FETCH_FROM_GLOBAL_STRIDED)
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", isc::FETCH_FROM_GLOBAL_CONTIGUOUS)
.value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL)
.value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED)
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS)
;
//Base
{
#define __PROP(name) .def_readonly(#name, &isaac::base::parameters_type::name)
bp::class_<isaac::base, boost::noncopyable>("base", bp::no_init)
.def("lmem_usage", &isaac::base::lmem_usage)
.def("registers_usage", &isaac::base::registers_usage)
.def("is_invalid", &isaac::base::is_invalid)
#define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
.def("lmem_usage", &tpt::base::lmem_usage)
.def("registers_usage", &tpt::base::registers_usage)
.def("is_invalid", &tpt::base::is_invalid)
.def("input_sizes", &detail::input_sizes)
;
#undef __PROP
}
#define WRAP_BASE(name) bp::class_<isaac::base_impl<isaac::name, isaac::name::parameters_type>, bp::bases<isaac::base>, boost::noncopyable>(#name, bp::no_init);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<isaac::name, bp::bases<isaac::base_impl<isaac::basename, isaac::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
.add_property("local_size_0", &isc::name::local_size_0)\
.add_property("local_size_1", &isc::name::local_size_1);
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
.add_property("local_size_0", &tpt::name::local_size_0)\
.add_property("local_size_1", &tpt::name::local_size_1);
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
//Vector AXPY
WRAP_SINGLE_TEMPLATE(vaxpy, uint, uint, uint, isaac::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(maxpy, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(reduction, uint, uint, uint, isaac::fetching_policy_type)
WRAP_BASE(mreduction)
WRAP_TEMPLATE(mreduction_rows, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
WRAP_TEMPLATE(mreduction_cols, mreduction, uint, uint, uint, uint, uint, isaac::fetching_policy_type)
WRAP_BASE(mproduct)
WRAP_TEMPLATE(mproduct_nn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_tn, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_nt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(mproduct_tt, mproduct, uint, uint, uint, uint, uint, uint, uint, uint, isaac::fetching_policy_type, isaac::fetching_policy_type, uint, uint)
WRAP_SINGLE_TEMPLATE(axpy, uint, uint, uint, tpt::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(ger, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_SINGLE_TEMPLATE(dot, uint, uint, uint, tpt::fetching_policy_type)
WRAP_BASE(gemv)
WRAP_TEMPLATE(gemv_n, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_TEMPLATE(gemv_t, gemv, uint, uint, uint, uint, uint, tpt::fetching_policy_type)
WRAP_BASE(gemm)
WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint)
}

View File

@@ -1,6 +1,6 @@
#ifndef ISAAC_PYTHON_MODEL_HPP
#define ISAAC_PYTHON_MODEL_HPP
void export_model();
void export_templates();
#endif

View File

@@ -4,7 +4,7 @@ if(CUDA_FOUND)
endif()
get_property(ISAAC_PATH TARGET isaac PROPERTY LOCATION)
foreach(PROG maxpy vaxpy reduction mreduction mproduct)
foreach(PROG axpy dot ger gemv gemm)
add_executable(${PROG}-test ${PROG}.cpp)
add_test(${PROG} ${PROG}-test)
target_link_libraries(${PROG}-test isaac ${OPENCL_LIBRARIES})

View File

@@ -14,10 +14,10 @@ from numpy import cumsum
import tools
fetch_types = [isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL]
fetch_types = [isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
isc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
isc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
isc.templates.fetching_policy_type.FETCH_FROM_LOCAL]
def exhaustive(template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context)
@@ -159,15 +159,15 @@ def is_local_optimum(parameters, template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context)
genetic_infos = tools.genetic_infos_of(template)
if issubclass(template, isc.vaxpy):
if issubclass(template, isc.axpy):
sweep_over = [0,1,2]
elif issubclass(template, isc.reduction):
elif issubclass(template, isc.dot):
sweep_over = [0,1,2]
elif issubclass(template, isc.maxpy):
elif issubclass(template, isc.ger):
sweep_over = [0,1,2,3,4]
elif issubclass(template, isc.mreduction):
elif issubclass(template, isc.gemv):
sweep_over = [0,1,2,3,4]
elif issubclass(template, isc.mproduct):
elif issubclass(template, isc.gemm):
sweep_over = [1,2,3,4,5,7,10,11]
#Evaluate the provided parameters guess

View File

@@ -36,30 +36,30 @@ def benchmark(template, setting, tree):
def tree_of(template, sizes, context):
if issubclass(template, isc.vaxpy):
if issubclass(template, isc.templates.axpy):
N, = sizes
x = isc.empty(N, dtype=isc.float32, context=context)
y = isc.empty(N, dtype=isc.float32, context=context)
return x + y, (x, y)
elif issubclass(template, isc.reduction):
elif issubclass(template, isc.templates.dot):
N, = sizes
x = isc.empty(N, context=context)
y = isc.empty(N, context=context)
return isc.dot(x, y), (x, y)
elif issubclass(template, isc.maxpy):
elif issubclass(template, isc.templates.ger):
M, N = sizes
A = isc.empty((M,N), context=context)
B = isc.empty((M,N), context=context)
return A + B, (A, B)
elif issubclass(template, isc.mreduction):
T = template is isc.mreduction_cols
elif issubclass(template, isc.templates.gemv):
T = template is isc.templates.gemv_t
M, N = sizes[::-1] if T else sizes
A = isc.empty((M,N), context=context)
x = isc.empty(N, context=context)
return isc.dot(A.T, x) if T else isc.dot(A, x), (A, x)
elif issubclass(template, isc.mproduct):
AT = template is isc.mproduct_tn or template is isc.mproduct_tt
BT = template is isc.mproduct_nt or template is isc.mproduct_tt
elif issubclass(template, isc.templates.gemm):
AT = template is isc.templates.gemm_tn or template is isc.templates.gemm_tt
BT = template is isc.templates.gemm_nt or template is isc.templates.gemm_tt
M, N, K = sizes
A = isc.empty((K, M) if AT else (M, K), context=context)
B = isc.empty((N, K) if BT else (K, N), context=context)
@@ -68,35 +68,35 @@ def tree_of(template, sizes, context):
return isc.dot(AA, BB), (A, B)
def memory_footprint(template, sizes):
if issubclass(template, isc.vaxpy):
if issubclass(template, isc.templates.axpy):
return 4*3*sizes[0]*1e-9
elif issubclass(template, isc.reduction):
elif issubclass(template, isc.templates.dot):
return 4*2*sizes[0]*1e-9
elif issubclass(template, isc.maxpy):
elif issubclass(template, isc.templates.ger):
return 4*3*sizes[0]*sizes[1]*1e-9
elif issubclass(template, isc.mreduction):
elif issubclass(template, isc.templates.gemv):
return 4*sizes[0]*sizes[1]*1e-9
elif issubclass(template, isc.mproduct):
elif issubclass(template, isc.templates.gemm):
return 4*(sizes[0]*sizes[1] + sizes[0]*sizes[2] + sizes[1]*sizes[2])*1e-9
def metric_of(template):
memory_bound = [isc.vaxpy, isc.reduction, isc.maxpy, isc.mreduction]
compute_bound = [isc.mproduct]
memory_bound = [isc.templates.axpy, isc.templates.dot, isc.templates.ger, isc.templates.gemv]
compute_bound = [isc.templates.gemm]
if any([issubclass(template, x) for x in memory_bound]):
return lambda sizes, t: memory_footprint(template, sizes)/t
elif any([issubclass(template, x) for x in compute_bound]):
return lambda sizes, t: 2*sizes[0]*sizes[1]*sizes[2]*1e-9/t
def genetic_infos_of(template):
if issubclass(template, isc.vaxpy):
if issubclass(template, isc.templates.axpy):
return {'categorical': [3], 'nbits': [3,4,4,2] }
elif issubclass(template, isc.reduction):
elif issubclass(template, isc.templates.dot):
return {'categorical': [3], 'nbits':[3,4,4,2]}
elif issubclass(template, isc.maxpy):
elif issubclass(template, isc.templates.ger):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
elif issubclass(template, isc.mreduction):
elif issubclass(template, isc.templates.gemv):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]}
elif issubclass(template, isc.mproduct):
elif issubclass(template, isc.templates.gemm):
return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]}

View File

@@ -23,14 +23,14 @@ def tune(device, operation, json_path):
#List of size tuples to use
sizes = {}
sizes[isc.vaxpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
sizes[isc.mreduction_rows] = product(pow2range(4,17), pow2range(4,17))
sizes[isc.mreduction_cols] = isc.mreduction_rows
sizes[isc.mproduct_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
sizes[isc.mproduct_nn] = [(128, 169, 1728)]
sizes[isc.mproduct_tn] = sizes[isc.mproduct_nn]
sizes[isc.mproduct_nt] = sizes[isc.mproduct_nn]
sizes[isc.mproduct_tt] = sizes[isc.mproduct_nn]
sizes[isc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
sizes[isc.templates.gemv_n] = product(pow2range(4,17), pow2range(4,17))
sizes[isc.templates.gemv_t] = sizes[isc.templates.gemv_n]
sizes[isc.templates.gemm_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
sizes[isc.templates.gemm_nn] = [(128, 169, 1728)]
sizes[isc.templates.gemm_tn] = sizes[isc.templates.gemm_nn]
sizes[isc.templates.gemm_nt] = sizes[isc.templates.gemm_nn]
sizes[isc.templates.gemm_tt] = sizes[isc.templates.gemm_nn]
sizes = unique(list(sizes[operation]))
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
@@ -123,9 +123,9 @@ def parse_arguments():
print("----------------")
operation = {'vaxpy': isc.vaxpy, 'dot': isc.reduction,
'maxpy': isc.maxpy, 'gemv_n': isc.mreduction_rows, 'gemv_t': isc.mreduction_cols,
'gemm_nn': isc.mproduct_nn, 'gemm_tn': isc.mproduct_tn, 'gemm_nt': isc.mproduct_nt, 'gemm_tt':isc.mproduct_tt}[args.operation]
operation = {'axpy': isc.templates.axpy, 'dot': isc.templates.dot,
'ger': isc.templates.ger, 'gemv_n': isc.templates.gemv_n, 'gemv_t': isc.templates.gemv_t,
'gemm_nn': isc.templates.gemm_nn, 'gemm_tn': isc.templates.gemm_tn, 'gemm_nt': isc.templates.gemm_nt, 'gemm_tt':isc.templates.gemm_tt}[args.operation]
if not args.json:
json = tools.sanitize(device.name) + '.json'
return (device, operation, json)