More cleaning
This commit is contained in:
@@ -26,6 +26,7 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
#include "isaac/types.h"
|
#include "isaac/types.h"
|
||||||
#include "isaac/jit/generation/engine/stream.h"
|
#include "isaac/jit/generation/engine/stream.h"
|
||||||
@@ -75,25 +76,24 @@ class base
|
|||||||
public:
|
public:
|
||||||
struct parameters_type
|
struct parameters_type
|
||||||
{
|
{
|
||||||
parameters_type(unsigned int _vwidth, int_t _ls0, int_t _ls1, int_t _num_kernels);
|
parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels);
|
||||||
unsigned int vwidth;
|
uint32_t vwidth;
|
||||||
unsigned int ls0;
|
uint32_t ls0;
|
||||||
unsigned int ls1;
|
uint32_t ls1;
|
||||||
unsigned int num_kernels;
|
uint32_t nkernels;
|
||||||
};
|
};
|
||||||
private:
|
private:
|
||||||
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0;
|
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0;
|
||||||
public:
|
public:
|
||||||
base();
|
base();
|
||||||
virtual unsigned int temporary_workspace(expression_tree const &) const;
|
|
||||||
virtual unsigned int lmem_usage(expression_tree const &) const;
|
|
||||||
virtual unsigned int registers_usage(expression_tree const &) const;
|
|
||||||
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
|
|
||||||
virtual ~base();
|
virtual ~base();
|
||||||
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
|
virtual uint32_t temporary_workspace(expression_tree const &) const;
|
||||||
|
virtual uint32_t lmem_usage(expression_tree const &) const;
|
||||||
|
virtual uint32_t registers_usage(expression_tree const &) const;
|
||||||
|
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
|
||||||
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
|
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
|
||||||
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
|
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
|
||||||
virtual std::shared_ptr<base> clone() const = 0;
|
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -105,9 +105,8 @@ private:
|
|||||||
public:
|
public:
|
||||||
typedef ParametersType parameters_type;
|
typedef ParametersType parameters_type;
|
||||||
base_impl(parameters_type const & parameters);
|
base_impl(parameters_type const & parameters);
|
||||||
unsigned int ls0() const;
|
uint32_t ls0() const;
|
||||||
unsigned int ls1() const;
|
uint32_t ls1() const;
|
||||||
std::shared_ptr<base> clone() const;
|
|
||||||
/** @brief returns whether or not the profile has undefined behavior on particular device */
|
/** @brief returns whether or not the profile has undefined behavior on particular device */
|
||||||
int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
|
int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
|
||||||
protected:
|
protected:
|
||||||
|
@@ -32,8 +32,8 @@ namespace templates
|
|||||||
class elementwise_1d_parameters : public base::parameters_type
|
class elementwise_1d_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
elementwise_1d_parameters(unsigned int _vwidth, unsigned int _group_size, unsigned int _num_groups, fetch_type _fetch);
|
elementwise_1d_parameters(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch);
|
||||||
unsigned int num_groups;
|
uint32_t ng;
|
||||||
fetch_type fetch;
|
fetch_type fetch;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ private:
|
|||||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const;
|
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const;
|
||||||
public:
|
public:
|
||||||
elementwise_1d(elementwise_1d::parameters_type const & parameters);
|
elementwise_1d(elementwise_1d::parameters_type const & parameters);
|
||||||
elementwise_1d(unsigned int _vwidth, unsigned int _group_size, unsigned int _num_groups, fetch_type _fetch);
|
elementwise_1d(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch);
|
||||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||||
};
|
};
|
||||||
|
@@ -33,10 +33,10 @@ namespace templates
|
|||||||
class elementwise_2d_parameters : public base::parameters_type
|
class elementwise_2d_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
elementwise_2d_parameters(unsigned int _vwidth, unsigned int _ls0, unsigned int _ls1, unsigned int _num_groups_0, unsigned int _num_groups_1, fetch_type _fetch);
|
elementwise_2d_parameters(uint32_t _vwidth, uint32_t _ls0, uint32_t _ls1, uint32_t _ng0, uint32_t _ng1, fetch_type _fetch);
|
||||||
|
|
||||||
unsigned int num_groups_0;
|
uint32_t ng0;
|
||||||
unsigned int num_groups_1;
|
uint32_t ng1;
|
||||||
fetch_type fetch;
|
fetch_type fetch;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ private:
|
|||||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
||||||
public:
|
public:
|
||||||
elementwise_2d(parameters_type const & parameters);
|
elementwise_2d(parameters_type const & parameters);
|
||||||
elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetch_type fetch);
|
elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||||
};
|
};
|
||||||
|
@@ -80,9 +80,9 @@ ADD_KEYWORD(GroupIdx0, "get_group_id(0)", "blockIdx.x")
|
|||||||
ADD_KEYWORD(GroupIdx1, "get_group_id(1)", "blockIdx.y")
|
ADD_KEYWORD(GroupIdx1, "get_group_id(1)", "blockIdx.y")
|
||||||
ADD_KEYWORD(GroupIdx2, "get_group_id(2)", "blockIdx.z")
|
ADD_KEYWORD(GroupIdx2, "get_group_id(2)", "blockIdx.z")
|
||||||
|
|
||||||
ADD_KEYWORD(GroupSize0, "get_num_groups(0)", "GridDim.x")
|
ADD_KEYWORD(GroupSize0, "get_ng(0)", "GridDim.x")
|
||||||
ADD_KEYWORD(GroupSize1, "get_num_groups(1)", "GridDim.y")
|
ADD_KEYWORD(GroupSize1, "get_ng(1)", "GridDim.y")
|
||||||
ADD_KEYWORD(GroupSize2, "get_num_groups(2)", "GridDim.z")
|
ADD_KEYWORD(GroupSize2, "get_ng(2)", "GridDim.z")
|
||||||
|
|
||||||
ADD_KEYWORD(LocalBarrier, "barrier(CLK_LOCAL_MEM_FENCE)", "__syncthreads()")
|
ADD_KEYWORD(LocalBarrier, "barrier(CLK_LOCAL_MEM_FENCE)", "__syncthreads()")
|
||||||
struct CastPrefix: public keyword{ CastPrefix(driver::backend_type backend, std::string const & datatype): keyword(backend, "convert_" + datatype, "make_" + datatype){} };
|
struct CastPrefix: public keyword{ CastPrefix(driver::backend_type backend, std::string const & datatype): keyword(backend, "convert_" + datatype, "make_" + datatype){} };
|
||||||
|
@@ -33,12 +33,12 @@ class kernel_generation_stream : public std::ostream
|
|||||||
class kgenstream : public std::stringbuf
|
class kgenstream : public std::stringbuf
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
kgenstream(std::ostringstream& oss,unsigned int const & tab_count) ;
|
kgenstream(std::ostringstream& oss,uint32_t const & tab_count) ;
|
||||||
int sync();
|
int sync();
|
||||||
~kgenstream();
|
~kgenstream();
|
||||||
private:
|
private:
|
||||||
std::ostream& oss_;
|
std::ostream& oss_;
|
||||||
unsigned int const & tab_count_;
|
uint32_t const & tab_count_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void process(std::string& str);
|
void process(std::string& str);
|
||||||
@@ -51,7 +51,7 @@ public:
|
|||||||
void inc_tab();
|
void inc_tab();
|
||||||
void dec_tab();
|
void dec_tab();
|
||||||
private:
|
private:
|
||||||
unsigned int tab_count_;
|
uint32_t tab_count_;
|
||||||
driver::backend_type backend_;
|
driver::backend_type backend_;
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
};
|
};
|
||||||
|
@@ -33,27 +33,27 @@ namespace templates
|
|||||||
|
|
||||||
struct gemm_parameters : public base::parameters_type
|
struct gemm_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
gemm_parameters(unsigned int vwidth
|
gemm_parameters(uint32_t vwidth
|
||||||
, unsigned int ls0, unsigned int KL, unsigned int ls1, unsigned int D
|
,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D
|
||||||
, unsigned int ms, unsigned int ks, unsigned int ns
|
,uint32_t ms, uint32_t ks, uint32_t ns
|
||||||
, fetch_type Afetch, fetch_type Bfetch
|
,fetch_type Afetch, fetch_type Bfetch
|
||||||
, unsigned int lf0, unsigned int lf1);
|
,uint32_t lf0, uint32_t lf1);
|
||||||
|
|
||||||
unsigned int kL;
|
uint32_t kL;
|
||||||
unsigned int depth;
|
uint32_t depth;
|
||||||
|
|
||||||
unsigned int mS;
|
uint32_t mS;
|
||||||
unsigned int kS;
|
uint32_t kS;
|
||||||
unsigned int nS;
|
uint32_t nS;
|
||||||
|
|
||||||
fetch_type Afetch;
|
fetch_type Afetch;
|
||||||
fetch_type Bfetch;
|
fetch_type Bfetch;
|
||||||
|
|
||||||
unsigned int lf0;
|
uint32_t lf0;
|
||||||
unsigned int lf1;
|
uint32_t lf1;
|
||||||
|
|
||||||
unsigned int mL;
|
uint32_t mL;
|
||||||
unsigned int nL;
|
uint32_t nL;
|
||||||
|
|
||||||
bool prefetch;
|
bool prefetch;
|
||||||
bool unroll_outer;
|
bool unroll_outer;
|
||||||
@@ -62,9 +62,9 @@ struct gemm_parameters : public base::parameters_type
|
|||||||
class gemm : public base_impl<gemm, gemm_parameters>
|
class gemm : public base_impl<gemm, gemm_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
uint32_t temporary_workspace(expression_tree const & expressions) const;
|
||||||
unsigned int lmem_usage(expression_tree const & expressions) const;
|
uint32_t lmem_usage(expression_tree const & expressions) const;
|
||||||
unsigned int registers_usage(expression_tree const & expressions) const;
|
uint32_t registers_usage(expression_tree const & expressions) const;
|
||||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const &) const;
|
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const &) const;
|
||||||
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, const expression_tree::node &A, const expression_tree::node &B, const expression_tree::node &C,
|
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, const expression_tree::node &A, const expression_tree::node &B, const expression_tree::node &C,
|
||||||
@@ -83,7 +83,7 @@ private:
|
|||||||
class gemm_nn : public gemm
|
class gemm_nn : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
gemm_nn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_nn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||||
, int_t lf0, int_t lf1);
|
, int_t lf0, int_t lf1);
|
||||||
};
|
};
|
||||||
@@ -91,7 +91,7 @@ public:
|
|||||||
class gemm_tn : public gemm
|
class gemm_tn : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
gemm_tn(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_tn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||||
, int_t lf0, int_t lf1);
|
, int_t lf0, int_t lf1);
|
||||||
};
|
};
|
||||||
@@ -100,7 +100,7 @@ public:
|
|||||||
class gemm_nt : public gemm
|
class gemm_nt : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
gemm_nt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_nt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||||
, int_t lf0, int_t lf1);
|
, int_t lf0, int_t lf1);
|
||||||
};
|
};
|
||||||
@@ -109,7 +109,7 @@ public:
|
|||||||
class gemm_tt : public gemm
|
class gemm_tt : public gemm
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
gemm_tt(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
gemm_tt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||||
, int_t lf0, int_t lf1);
|
, int_t lf0, int_t lf1);
|
||||||
};
|
};
|
||||||
|
@@ -31,26 +31,26 @@ namespace templates
|
|||||||
|
|
||||||
struct reduce_1d_parameters : public base::parameters_type
|
struct reduce_1d_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
reduce_1d_parameters(unsigned int _vwidth,
|
reduce_1d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
uint32_t _group_size, uint32_t _ng,
|
||||||
fetch_type _fetch);
|
fetch_type _fetch);
|
||||||
unsigned int num_groups;
|
uint32_t ng;
|
||||||
fetch_type fetch;
|
fetch_type fetch;
|
||||||
};
|
};
|
||||||
|
|
||||||
class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
|
class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
unsigned int lmem_usage(expression_tree const & expressions) const;
|
uint32_t lmem_usage(expression_tree const & expressions) const;
|
||||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
uint32_t temporary_workspace(expression_tree const & expressions) const;
|
||||||
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
|
inline void reduce_1d_local_memory(kernel_generation_stream & stream, uint32_t size, std::vector<symbolic::reduce_1d*> exprs,
|
||||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
|
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
|
||||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
reduce_1d(reduce_1d::parameters_type const & parameters);
|
reduce_1d(reduce_1d::parameters_type const & parameters);
|
||||||
reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng, fetch_type fetch);
|
reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch);
|
||||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||||
private:
|
private:
|
||||||
|
@@ -33,11 +33,11 @@ namespace templates
|
|||||||
{
|
{
|
||||||
struct reduce_2d_parameters : public base::parameters_type
|
struct reduce_2d_parameters : public base::parameters_type
|
||||||
{
|
{
|
||||||
reduce_2d_parameters(unsigned int _vwidth,
|
reduce_2d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _ls0, unsigned int _ls1,
|
uint32_t _ls0, uint32_t _ls1,
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1, fetch_type _fetch_policy);
|
uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy);
|
||||||
unsigned int num_groups_0;
|
uint32_t ng0;
|
||||||
unsigned int num_groups_1;
|
uint32_t ng1;
|
||||||
fetch_type fetch_policy;
|
fetch_type fetch_policy;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -48,8 +48,8 @@ protected:
|
|||||||
reduce_2d(reduce_2d::parameters_type const & , operation_type_family);
|
reduce_2d(reduce_2d::parameters_type const & , operation_type_family);
|
||||||
private:
|
private:
|
||||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||||
unsigned int lmem_usage(expression_tree const &) const;
|
uint32_t lmem_usage(expression_tree const &) const;
|
||||||
unsigned int temporary_workspace(expression_tree const & expressions) const;
|
uint32_t temporary_workspace(expression_tree const & expressions) const;
|
||||||
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
|
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
|
||||||
public:
|
public:
|
||||||
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||||
@@ -62,14 +62,14 @@ class reduce_2d_rows : public reduce_2d
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
reduce_2d_rows(reduce_2d::parameters_type const &);
|
reduce_2d_rows(reduce_2d::parameters_type const &);
|
||||||
reduce_2d_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetch_type fetch);
|
reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||||
};
|
};
|
||||||
|
|
||||||
class reduce_2d_cols : public reduce_2d
|
class reduce_2d_cols : public reduce_2d
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
reduce_2d_cols(reduce_2d::parameters_type const &);
|
reduce_2d_cols(reduce_2d::parameters_type const &);
|
||||||
reduce_2d_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetch_type fetch);
|
reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -90,7 +90,7 @@ std::vector<size_t> rhs_of(expression_tree const & tree, std::vector<size_t> con
|
|||||||
std::string hash(expression_tree const & tree);
|
std::string hash(expression_tree const & tree);
|
||||||
|
|
||||||
//Set arguments
|
//Set arguments
|
||||||
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, unsigned int & current_arg);
|
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, uint32_t& current_arg);
|
||||||
|
|
||||||
//Symbolize
|
//Symbolize
|
||||||
symbols_table symbolize(isaac::expression_tree const & expression);
|
symbols_table symbolize(isaac::expression_tree const & expression);
|
||||||
|
@@ -53,7 +53,7 @@ public:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
value_type(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
value_type(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
||||||
value_type(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
|
value_type(expression_type, numeric_type, std::shared_ptr<templates::base> const &, driver::CommandQueue const &);
|
||||||
void execute(runtime::execution_handler const &);
|
void execute(runtime::execution_handler const &);
|
||||||
templates_container const & templates() const;
|
templates_container const & templates() const;
|
||||||
|
|
||||||
|
@@ -40,24 +40,23 @@ namespace isaac
|
|||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
|
|
||||||
base::parameters_type::parameters_type(unsigned int _vwidth, int_t _ls0, int_t _ls1, int_t _num_kernels) : vwidth(_vwidth), ls0(_ls0), ls1(_ls1), num_kernels(_num_kernels)
|
base::parameters_type::parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels) : vwidth(_vwidth), ls0(_ls0), ls1(_ls1), nkernels(_nkernels)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
base::base()
|
base::base()
|
||||||
{}
|
{}
|
||||||
|
|
||||||
unsigned int base::lmem_usage(expression_tree const &) const
|
uint32_t base::lmem_usage(expression_tree const &) const
|
||||||
{ return 0; }
|
{ return 0; }
|
||||||
|
|
||||||
unsigned int base::registers_usage(expression_tree const &) const
|
uint32_t base::registers_usage(expression_tree const &) const
|
||||||
{ return 0; }
|
{ return 0; }
|
||||||
|
|
||||||
unsigned int base::temporary_workspace(expression_tree const &) const
|
uint32_t base::temporary_workspace(expression_tree const &) const
|
||||||
{ return 0; }
|
{ return 0; }
|
||||||
|
|
||||||
base::~base()
|
base::~base()
|
||||||
{
|
{ }
|
||||||
}
|
|
||||||
|
|
||||||
std::string base::generate(std::string const & suffix, expression_tree const & expression, driver::Device const & device)
|
std::string base::generate(std::string const & suffix, expression_tree const & expression, driver::Device const & device)
|
||||||
{
|
{
|
||||||
@@ -79,17 +78,13 @@ base_impl<TType, PType>::base_impl(parameters_type const & parameters) : base(),
|
|||||||
{ }
|
{ }
|
||||||
|
|
||||||
template<class TType, class PType>
|
template<class TType, class PType>
|
||||||
unsigned int base_impl<TType, PType>::ls0() const
|
uint32_t base_impl<TType, PType>::ls0() const
|
||||||
{ return p_.ls0; }
|
{ return p_.ls0; }
|
||||||
|
|
||||||
template<class TType, class PType>
|
template<class TType, class PType>
|
||||||
unsigned int base_impl<TType, PType>::ls1() const
|
uint32_t base_impl<TType, PType>::ls1() const
|
||||||
{ return p_.ls1; }
|
{ return p_.ls1; }
|
||||||
|
|
||||||
template<class TType, class PType>
|
|
||||||
std::shared_ptr<base> base_impl<TType, PType>::clone() const
|
|
||||||
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
|
|
||||||
|
|
||||||
template<class TType, class PType>
|
template<class TType, class PType>
|
||||||
int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
||||||
{
|
{
|
||||||
|
@@ -36,10 +36,10 @@ namespace isaac
|
|||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
|
|
||||||
elementwise_1d_parameters::elementwise_1d_parameters(unsigned int _vwidth,
|
elementwise_1d_parameters::elementwise_1d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
uint32_t _group_size, uint32_t _ng,
|
||||||
fetch_type _fetch) :
|
fetch_type _fetch) :
|
||||||
base::parameters_type(_vwidth, _group_size, 1, 1), num_groups(_num_groups), fetch(_fetch)
|
base::parameters_type(_vwidth, _group_size, 1, 1), ng(_ng), fetch(_fetch)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
}
|
}
|
||||||
|
|
||||||
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth)
|
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||||
{
|
{
|
||||||
std::string dtype = append_width("#scalartype",vwidth);
|
std::string dtype = append_width("#scalartype",vwidth);
|
||||||
|
|
||||||
@@ -97,12 +97,12 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
|||||||
|
|
||||||
//Compute
|
//Compute
|
||||||
for(size_t idx: assignments)
|
for(size_t idx: assignments)
|
||||||
for(unsigned int s = 0 ; s < vwidth ; ++s)
|
for(uint32_t s = 0 ; s < vwidth ; ++s)
|
||||||
stream << symbols.at(idx)->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}}) << ";" << std::endl;
|
stream << symbols.at(idx)->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}}) << ";" << std::endl;
|
||||||
|
|
||||||
//Writes back
|
//Writes back
|
||||||
for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assignments_lhs, false))
|
for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assignments_lhs, false))
|
||||||
for(unsigned int s = 0 ; s < vwidth ; ++s)
|
for(uint32_t s = 0 ; s < vwidth ; ++s)
|
||||||
stream << sym->process("at(i+" + tools::to_string(s)+") = " + access_vector_type("#name", s, vwidth) + ";") << std::endl;
|
stream << sym->process("at(i+" + tools::to_string(s)+") = " + access_vector_type("#name", s, vwidth) + ";") << std::endl;
|
||||||
});
|
});
|
||||||
//Close user-provided for-loops
|
//Close user-provided for-loops
|
||||||
@@ -122,7 +122,7 @@ elementwise_1d::elementwise_1d(elementwise_1d_parameters const & parameters) :
|
|||||||
base_impl<elementwise_1d, elementwise_1d_parameters>(parameters)
|
base_impl<elementwise_1d, elementwise_1d_parameters>(parameters)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
elementwise_1d::elementwise_1d(unsigned int simd, unsigned int ls, unsigned int ng,
|
elementwise_1d::elementwise_1d(uint32_t simd, uint32_t ls, uint32_t ng,
|
||||||
fetch_type fetch):
|
fetch_type fetch):
|
||||||
base_impl<elementwise_1d, elementwise_1d_parameters>(elementwise_1d_parameters(simd,ls,ng,fetch))
|
base_impl<elementwise_1d, elementwise_1d_parameters>(elementwise_1d_parameters(simd,ls,ng,fetch))
|
||||||
{}
|
{}
|
||||||
@@ -143,10 +143,10 @@ void elementwise_1d::enqueue(driver::CommandQueue &, driver::Program const & pro
|
|||||||
name += suffix;
|
name += suffix;
|
||||||
driver::Kernel kernel(program, name.c_str());
|
driver::Kernel kernel(program, name.c_str());
|
||||||
//NDRange
|
//NDRange
|
||||||
driver::NDRange global(p_.ls0*p_.num_groups);
|
driver::NDRange global(p_.ls0*p_.ng);
|
||||||
driver::NDRange local(p_.ls0);
|
driver::NDRange local(p_.ls0);
|
||||||
//Arguments
|
//Arguments
|
||||||
unsigned int current_arg = 0;
|
uint32_t current_arg = 0;
|
||||||
kernel.setSizeArg(current_arg++, size);
|
kernel.setSizeArg(current_arg++, size);
|
||||||
symbolic::set_arguments(expressions, kernel, current_arg);
|
symbolic::set_arguments(expressions, kernel, current_arg);
|
||||||
control.execution_options().enqueue(program.context(), kernel, global, local);
|
control.execution_options().enqueue(program.context(), kernel, global, local);
|
||||||
|
@@ -33,10 +33,10 @@ namespace isaac
|
|||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
|
|
||||||
elementwise_2d_parameters::elementwise_2d_parameters(unsigned int _vwidth,
|
elementwise_2d_parameters::elementwise_2d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _ls0, unsigned int _ls1,
|
uint32_t _ls0, uint32_t _ls1,
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1,
|
uint32_t _ng0, uint32_t _ng1,
|
||||||
fetch_type _fetch) : base::parameters_type(_vwidth, _ls0, _ls1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch(_fetch){ }
|
fetch_type _fetch) : base::parameters_type(_vwidth, _ls0, _ls1, 1), ng0(_ng0), ng1(_ng1), fetch(_fetch){ }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -114,8 +114,8 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
|
|||||||
elementwise_2d::elementwise_2d(parameters_type const & parameters) :
|
elementwise_2d::elementwise_2d(parameters_type const & parameters) :
|
||||||
base_impl<elementwise_2d, elementwise_2d_parameters>(parameters){ }
|
base_impl<elementwise_2d, elementwise_2d_parameters>(parameters){ }
|
||||||
|
|
||||||
elementwise_2d::elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
elementwise_2d::elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2,
|
||||||
unsigned int ng1, unsigned int ng2, fetch_type fetch):
|
uint32_t ng1, uint32_t ng2, fetch_type fetch):
|
||||||
base_impl<elementwise_2d, elementwise_2d_parameters>(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch))
|
base_impl<elementwise_2d, elementwise_2d_parameters>(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
@@ -129,9 +129,9 @@ void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program c
|
|||||||
std::string name = "elementwise_2d";
|
std::string name = "elementwise_2d";
|
||||||
name +=suffix;
|
name +=suffix;
|
||||||
driver::Kernel kernel(program, name.c_str());
|
driver::Kernel kernel(program, name.c_str());
|
||||||
driver::NDRange global(p_.ls0*p_.num_groups_0, p_.ls1*p_.num_groups_1);
|
driver::NDRange global(p_.ls0*p_.ng0, p_.ls1*p_.ng1);
|
||||||
driver::NDRange local(p_.ls0, p_.ls1);
|
driver::NDRange local(p_.ls0, p_.ls1);
|
||||||
unsigned int current_arg = 0;
|
uint32_t current_arg = 0;
|
||||||
std::vector<int_t> MN = input_sizes(expressions);
|
std::vector<int_t> MN = input_sizes(expressions);
|
||||||
kernel.setSizeArg(current_arg++, MN[0]);
|
kernel.setSizeArg(current_arg++, MN[0]);
|
||||||
kernel.setSizeArg(current_arg++, MN[1]);
|
kernel.setSizeArg(current_arg++, MN[1]);
|
||||||
|
@@ -25,14 +25,14 @@
|
|||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
|
||||||
kernel_generation_stream::kgenstream::kgenstream(std::ostringstream& oss,unsigned int const & tab_count) :
|
kernel_generation_stream::kgenstream::kgenstream(std::ostringstream& oss,uint32_t const & tab_count) :
|
||||||
oss_(oss), tab_count_(tab_count)
|
oss_(oss), tab_count_(tab_count)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
|
||||||
int kernel_generation_stream::kgenstream::sync()
|
int kernel_generation_stream::kgenstream::sync()
|
||||||
{
|
{
|
||||||
for (unsigned int i=0; i<tab_count_;++i)
|
for (uint32_t i=0; i<tab_count_;++i)
|
||||||
oss_ << " ";
|
oss_ << " ";
|
||||||
std::string next = str();
|
std::string next = str();
|
||||||
oss_ << next;
|
oss_ << next;
|
||||||
@@ -69,9 +69,9 @@ ADD_KEYWORD("GROUP_IDX_0", "get_group_id(0)", "blockIdx.x")
|
|||||||
ADD_KEYWORD("GROUP_IDX_1", "get_group_id(1)", "blockIdx.y")
|
ADD_KEYWORD("GROUP_IDX_1", "get_group_id(1)", "blockIdx.y")
|
||||||
ADD_KEYWORD("GROUP_IDX_2", "get_group_id(2)", "blockIdx.z")
|
ADD_KEYWORD("GROUP_IDX_2", "get_group_id(2)", "blockIdx.z")
|
||||||
|
|
||||||
ADD_KEYWORD("GROUP_SIZE_0", "get_num_groups(0)", "GridDim.x")
|
ADD_KEYWORD("GROUP_SIZE_0", "get_ng(0)", "GridDim.x")
|
||||||
ADD_KEYWORD("GROUP_SIZE_1", "get_num_groups(1)", "GridDim.y")
|
ADD_KEYWORD("GROUP_SIZE_1", "get_ng(1)", "GridDim.y")
|
||||||
ADD_KEYWORD("GROUP_SIZE_2", "get_num_groups(2)", "GridDim.z")
|
ADD_KEYWORD("GROUP_SIZE_2", "get_ng(2)", "GridDim.z")
|
||||||
|
|
||||||
ADD_KEYWORD("LOCAL_BARRIER", "barrier(CLK_LOCAL_MEM_FENCE)", "__syncthreads()")
|
ADD_KEYWORD("LOCAL_BARRIER", "barrier(CLK_LOCAL_MEM_FENCE)", "__syncthreads()")
|
||||||
ADD_KEYWORD("LOCAL_PTR", "__local", "")
|
ADD_KEYWORD("LOCAL_PTR", "__local", "")
|
||||||
|
@@ -37,11 +37,11 @@ namespace isaac
|
|||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
|
|
||||||
gemm_parameters::gemm_parameters(unsigned int vwidth
|
gemm_parameters::gemm_parameters(uint32_t vwidth
|
||||||
, unsigned int ls0, unsigned int KL, unsigned int ls1, unsigned int D
|
,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D
|
||||||
, unsigned int ms, unsigned int ks, unsigned int ns
|
,uint32_t ms, uint32_t ks, uint32_t ns
|
||||||
, fetch_type Afetch, fetch_type Bfetch
|
,fetch_type Afetch, fetch_type Bfetch
|
||||||
, unsigned int lf0, unsigned int lf1): base::parameters_type(vwidth, ls0, ls1, 1),
|
,uint32_t lf0, uint32_t lf1): base::parameters_type(vwidth, ls0, ls1, 1),
|
||||||
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
|
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
|
||||||
lf0(lf0), lf1(lf1),
|
lf0(lf0), lf1(lf1),
|
||||||
mL(ms*ls0), nL(ns*ls1)
|
mL(ms*ls0), nL(ns*ls1)
|
||||||
@@ -49,9 +49,9 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unsigned int gemm::lmem_usage(expression_tree const & expression) const
|
uint32_t gemm::lmem_usage(expression_tree const & expression) const
|
||||||
{
|
{
|
||||||
unsigned int N = 0;
|
uint32_t N = 0;
|
||||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
||||||
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
||||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
||||||
@@ -61,13 +61,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
return N*size_of(expression.dtype());
|
return N*size_of(expression.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int gemm::registers_usage(expression_tree const & expression) const
|
uint32_t gemm::registers_usage(expression_tree const & expression) const
|
||||||
{
|
{
|
||||||
unsigned int N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
|
uint32_t N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
|
||||||
return N*size_of(expression.dtype());
|
return N*size_of(expression.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int gemm::temporary_workspace(expression_tree const & expressions) const
|
uint32_t gemm::temporary_workspace(expression_tree const & expressions) const
|
||||||
{
|
{
|
||||||
std::vector<int_t> MNK = input_sizes(expressions);
|
std::vector<int_t> MNK = input_sizes(expressions);
|
||||||
int_t M = MNK[0]; int_t N = MNK[1];
|
int_t M = MNK[0]; int_t N = MNK[1];
|
||||||
@@ -97,8 +97,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
|
|
||||||
if (p_.Afetch==FETCH_FROM_LOCAL)
|
if (p_.Afetch==FETCH_FROM_LOCAL)
|
||||||
{
|
{
|
||||||
unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
|
uint32_t bound1 = (A_trans_=='N')?p_.kL:p_.mL;
|
||||||
unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
|
uint32_t bound0 = (A_trans_=='N')?p_.mL:p_.kL;
|
||||||
|
|
||||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||||
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||||
@@ -109,8 +109,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
if (p_.Bfetch==FETCH_FROM_LOCAL)
|
if (p_.Bfetch==FETCH_FROM_LOCAL)
|
||||||
{
|
{
|
||||||
unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
|
uint32_t bound1 = (B_trans_=='T')?p_.kL:p_.nL;
|
||||||
unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
|
uint32_t bound0 = (B_trans_=='T')?p_.nL:p_.kL;
|
||||||
|
|
||||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||||
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||||
@@ -190,8 +190,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
||||||
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
|
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
|
||||||
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
|
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
|
||||||
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
uint32_t npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
||||||
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
uint32_t npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
||||||
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||||
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
||||||
stream << std::endl;
|
stream << std::endl;
|
||||||
@@ -290,13 +290,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
stream << std::endl;
|
stream << std::endl;
|
||||||
|
|
||||||
for(unsigned int i = 0 ; i < npA ; i++ )
|
for(uint32_t i = 0 ; i < npA ; i++ )
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
||||||
else
|
else
|
||||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
|
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
|
||||||
|
|
||||||
for(unsigned int i = 0 ; i < npB ; i++ )
|
for(uint32_t i = 0 ; i < npB ; i++ )
|
||||||
if (B_trans_=='T')
|
if (B_trans_=='T')
|
||||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
||||||
else
|
else
|
||||||
@@ -318,13 +318,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
stream << "//Fetch A to local memory" << std::endl;
|
stream << "//Fetch A to local memory" << std::endl;
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
{
|
{
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||||
for(unsigned int m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
|
for(uint32_t m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
|
||||||
{
|
{
|
||||||
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
|
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
if(last_iteration)
|
if(last_iteration)
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
|
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
|
||||||
else
|
else
|
||||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
|
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
|
||||||
@@ -332,13 +332,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||||
for(unsigned int m = 0; m < p_.mL; m += p_.lf1)
|
for(uint32_t m = 0; m < p_.mL; m += p_.lf1)
|
||||||
{
|
{
|
||||||
std::string mm = to_string(m/p_.lf1);
|
std::string mm = to_string(m/p_.lf1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
if(last_iteration)
|
if(last_iteration)
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
|
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
|
||||||
|
|
||||||
else
|
else
|
||||||
@@ -349,13 +349,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
stream << "//Fetch B to local memory" << std::endl;
|
stream << "//Fetch B to local memory" << std::endl;
|
||||||
if (B_trans_=='T')
|
if (B_trans_=='T')
|
||||||
{
|
{
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||||
for(unsigned int n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
|
for(uint32_t n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
|
||||||
{
|
{
|
||||||
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
|
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
if(last_iteration)
|
if(last_iteration)
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
|
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
|
||||||
else
|
else
|
||||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
||||||
@@ -363,13 +363,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||||
for(unsigned int n = 0; n < p_.nL; n += p_.lf1)
|
for(uint32_t n = 0; n < p_.nL; n += p_.lf1)
|
||||||
{
|
{
|
||||||
std::string nn = to_string(n/p_.lf1);
|
std::string nn = to_string(n/p_.lf1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
if(last_iteration)
|
if(last_iteration)
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
|
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
|
||||||
|
|
||||||
else
|
else
|
||||||
@@ -391,14 +391,14 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
|
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
|
||||||
size_t ks = last_iteration?1:p_.kS;
|
size_t ks = last_iteration?1:p_.kS;
|
||||||
stream << "//Inner loop" << std::endl;
|
stream << "//Inner loop" << std::endl;
|
||||||
stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
|
stream << "for(uint32_t k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
stream << "//Fetch A to registers" << std::endl;
|
stream << "//Fetch A to registers" << std::endl;
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||||
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
||||||
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
stream << "for(uint32_t mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
if(A_trans_=='N')
|
if(A_trans_=='N')
|
||||||
@@ -408,7 +408,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
if(p_.vwidth==1)
|
if(p_.vwidth==1)
|
||||||
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
|
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -417,9 +417,9 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
|
|
||||||
stream << "//Fetch B to registers" << std::endl;
|
stream << "//Fetch B to registers" << std::endl;
|
||||||
stream << "#pragma unroll " << ks << std::endl;
|
stream << "#pragma unroll " << ks << std::endl;
|
||||||
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||||
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
||||||
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
stream << "for(uint32_t nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
if(B_trans_=='T')
|
if(B_trans_=='T')
|
||||||
@@ -429,7 +429,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
if(p_.vwidth==1)
|
if(p_.vwidth==1)
|
||||||
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
|
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
||||||
}
|
}
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
@@ -437,10 +437,10 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
|
|
||||||
stream << "//FMA computations" << std::endl;
|
stream << "//FMA computations" << std::endl;
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
|
stream << "for(uint32_t kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
for(unsigned int nn=0; nn < p_.nS; ++nn)
|
for(uint32_t nn=0; nn < p_.nS; ++nn)
|
||||||
for(unsigned int mm=0; mm < p_.mS; ++mm){
|
for(uint32_t mm=0; mm < p_.mS; ++mm){
|
||||||
string res_str, lhs_str, rhs_str;
|
string res_str, lhs_str, rhs_str;
|
||||||
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
||||||
if (p_.vwidth==1)
|
if (p_.vwidth==1)
|
||||||
@@ -461,18 +461,18 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
|
|
||||||
//Increment A pointers to global memory
|
//Increment A pointers to global memory
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
for(unsigned int i = 0 ; i < npA ; ++i)
|
for(uint32_t i = 0 ; i < npA ; ++i)
|
||||||
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
|
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int i = 0 ; i < npA ; ++i)
|
for(uint32_t i = 0 ; i < npA ; ++i)
|
||||||
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
||||||
|
|
||||||
//Increment B pointers to global memory
|
//Increment B pointers to global memory
|
||||||
if (B_trans_=='T')
|
if (B_trans_=='T')
|
||||||
for(unsigned int i = 0 ; i < npB ; ++i)
|
for(uint32_t i = 0 ; i < npB ; ++i)
|
||||||
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int i = 0 ; i < npB ; ++i)
|
for(uint32_t i = 0 ; i < npB ; ++i)
|
||||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||||
};
|
};
|
||||||
fetch_to_lds(false);
|
fetch_to_lds(false);
|
||||||
@@ -483,15 +483,15 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
if(A_trans_=='N' || B_trans_=='T')
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
{
|
{
|
||||||
stream << "int Ky = K - idT.y;" << std::endl;
|
stream << "int Ky = K - idT.y;" << std::endl;
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(A_trans_=='T' || B_trans_=='N')
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
{
|
{
|
||||||
stream << "int Kx = K - idT.x;" << std::endl;
|
stream << "int Kx = K - idT.x;" << std::endl;
|
||||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
|
for(uint32_t k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
|
||||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||||
}
|
}
|
||||||
fetch_to_lds(true);
|
fetch_to_lds(true);
|
||||||
@@ -522,13 +522,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
stream << "N -= ids.y;" << std::endl;
|
stream << "N -= ids.y;" << std::endl;
|
||||||
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
|
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
|
||||||
|
|
||||||
for(unsigned int n=0; n < p_.nS; ++n)
|
for(uint32_t n=0; n < p_.nS; ++n)
|
||||||
{
|
{
|
||||||
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
|
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
|
||||||
stream << "if(" << Cj << " >= N) return;" << std::endl;
|
stream << "if(" << Cj << " >= N) return;" << std::endl;
|
||||||
for(unsigned int m=0; m < p_.mS; ++m)
|
for(uint32_t m=0; m < p_.mS; ++m)
|
||||||
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
|
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
|
||||||
for(unsigned int m=0; m < p_.mS; ++m)
|
for(uint32_t m=0; m < p_.mS; ++m)
|
||||||
{
|
{
|
||||||
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
|
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
|
||||||
stream << "if(" << Ci << "< M) ";
|
stream << "if(" << Ci << "< M) ";
|
||||||
@@ -560,14 +560,14 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
stream << "C += Cstart;" << std::endl;
|
stream << "C += Cstart;" << std::endl;
|
||||||
stream << "for(unsigned int i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
|
stream << "for(uint32_t i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << "for(unsigned int j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
|
stream << "for(uint32_t j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << sdtype << " acc = 0;" << std::endl;
|
stream << sdtype << " acc = 0;" << std::endl;
|
||||||
stream << "for(unsigned int k = 0 ; k < D ; k++)" << std::endl;
|
stream << "for(uint32_t k = 0 ; k < D ; k++)" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
|
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
@@ -609,7 +609,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
driver::NDRange local(p_.ls0, p_.ls1, 1);
|
driver::NDRange local(p_.ls0, p_.ls1, 1);
|
||||||
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
|
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
|
||||||
|
|
||||||
unsigned int current_arg = 0;
|
uint32_t current_arg = 0;
|
||||||
|
|
||||||
driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context()));
|
driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context()));
|
||||||
gemm.setSizeArg(current_arg++, M);
|
gemm.setSizeArg(current_arg++, M);
|
||||||
@@ -656,7 +656,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
|
|
||||||
if(p_.depth > 1)
|
if(p_.depth > 1)
|
||||||
{
|
{
|
||||||
unsigned int current_arg = 0;
|
uint32_t current_arg = 0;
|
||||||
driver::Kernel reduce(program, reduce_name.c_str());
|
driver::Kernel reduce(program, reduce_name.c_str());
|
||||||
driver::NDRange local(p_.ls0, p_.ls1);
|
driver::NDRange local(p_.ls0, p_.ls1);
|
||||||
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
|
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
|
||||||
@@ -721,7 +721,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
gemm_nn::gemm_nn(unsigned int simd
|
gemm_nn::gemm_nn(uint32_t simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetch_type Afetch , fetch_type Bfetch
|
, fetch_type Afetch , fetch_type Bfetch
|
||||||
@@ -731,7 +731,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
gemm_tn::gemm_tn(unsigned int simd
|
gemm_tn::gemm_tn(uint32_t simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetch_type Afetch , fetch_type Bfetch
|
, fetch_type Afetch , fetch_type Bfetch
|
||||||
@@ -740,7 +740,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
{ }
|
{ }
|
||||||
|
|
||||||
//
|
//
|
||||||
gemm_nt::gemm_nt(unsigned int simd
|
gemm_nt::gemm_nt(uint32_t simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetch_type Afetch , fetch_type Bfetch
|
, fetch_type Afetch , fetch_type Bfetch
|
||||||
@@ -749,7 +749,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
|||||||
{ }
|
{ }
|
||||||
|
|
||||||
//
|
//
|
||||||
gemm_tt::gemm_tt(unsigned int simd
|
gemm_tt::gemm_tt(uint32_t simd
|
||||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||||
, int_t ms, int_t ks, int_t ns
|
, int_t ms, int_t ks, int_t ns
|
||||||
, fetch_type Afetch , fetch_type Bfetch
|
, fetch_type Afetch , fetch_type Bfetch
|
||||||
|
@@ -35,12 +35,12 @@ namespace isaac
|
|||||||
{
|
{
|
||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
reduce_1d_parameters::reduce_1d_parameters(unsigned int _vwidth,
|
reduce_1d_parameters::reduce_1d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _group_size, unsigned int _num_groups,
|
uint32_t _group_size, uint32_t _ng,
|
||||||
fetch_type _fetch) : base::parameters_type(_vwidth, _group_size, 1, 2), num_groups(_num_groups), fetch(_fetch)
|
fetch_type _fetch) : base::parameters_type(_vwidth, _group_size, 1, 2), ng(_ng), fetch(_fetch)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
unsigned int reduce_1d::lmem_usage(expression_tree const & x) const
|
uint32_t reduce_1d::lmem_usage(expression_tree const & x) const
|
||||||
{
|
{
|
||||||
return p_.ls0*size_of(x.dtype());
|
return p_.ls0*size_of(x.dtype());
|
||||||
}
|
}
|
||||||
@@ -52,18 +52,18 @@ int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &)
|
|||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
|
uint32_t reduce_1d::temporary_workspace(expression_tree const &) const
|
||||||
{
|
{
|
||||||
if(p_.num_groups > 1)
|
if(p_.ng > 1)
|
||||||
return p_.num_groups;
|
return p_.ng;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
|
inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, uint32_t size, std::vector<symbolic::reduce_1d*> exprs,
|
||||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type) const
|
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type) const
|
||||||
{
|
{
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
stream << "for(unsigned int stride = " << size/2 << "; stride > 0; stride /=2)" << std::endl;
|
stream << "for(uint32_t stride = " << size/2 << "; stride > 0; stride /=2)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||||
@@ -95,7 +95,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
|
|
||||||
auto unroll_tmp = [&]()
|
auto unroll_tmp = [&]()
|
||||||
{
|
{
|
||||||
unsigned int offset = 0;
|
uint32_t offset = 0;
|
||||||
for(symbolic::reduce_1d* rd: reductions)
|
for(symbolic::reduce_1d* rd: reductions)
|
||||||
{
|
{
|
||||||
numeric_type dtype = tree.dtype();
|
numeric_type dtype = tree.dtype();
|
||||||
@@ -103,13 +103,13 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
if (is_indexing(rd->op().type))
|
if (is_indexing(rd->op().type))
|
||||||
{
|
{
|
||||||
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint *)(tmp + " + tools::to_string(offset) + ");");
|
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint *)(tmp + " + tools::to_string(offset) + ");");
|
||||||
offset += 4*p_.num_groups;
|
offset += 4*p_.ng;
|
||||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||||
offset += size_of(dtype)*p_.num_groups;
|
offset += size_of(dtype)*p_.ng;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||||
offset += size_of(dtype)*p_.num_groups;
|
offset += size_of(dtype)*p_.ng;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -130,10 +130,10 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
unroll_tmp();
|
unroll_tmp();
|
||||||
//Declare
|
//Declare
|
||||||
stream << "unsigned int lid = $LOCAL_IDX_0;" << std::endl;
|
stream << "uint32_t lid = $LOCAL_IDX_0;" << std::endl;
|
||||||
stream << "unsigned int gid = $GLOBAL_IDX_0;" << std::endl;
|
stream << "uint32_t gid = $GLOBAL_IDX_0;" << std::endl;
|
||||||
stream << "unsigned int gpid = $GROUP_IDX_0;" << std::endl;
|
stream << "uint32_t gpid = $GROUP_IDX_0;" << std::endl;
|
||||||
stream << "unsigned int gsize = $GLOBAL_SIZE_0;" << std::endl;
|
stream << "uint32_t gsize = $GLOBAL_SIZE_0;" << std::endl;
|
||||||
|
|
||||||
for(symbolic::reduce_1d* rd: reductions)
|
for(symbolic::reduce_1d* rd: reductions)
|
||||||
{
|
{
|
||||||
@@ -141,8 +141,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
{
|
{
|
||||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||||
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
||||||
stream << rd->process("$LOCAL unsigned int #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||||
stream << rd->process("unsigned int #name_acc = 0;") << std::endl;
|
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -150,7 +150,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth)
|
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||||
{
|
{
|
||||||
std::string dtype = append_width("#scalartype",vwidth);
|
std::string dtype = append_width("#scalartype",vwidth);
|
||||||
//Fetch vector entry
|
//Fetch vector entry
|
||||||
@@ -161,7 +161,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream << leaf->process(dtype + " #name = " + append_width("loadv", vwidth) + "(i);") << std::endl;
|
stream << leaf->process(dtype + " #name = " + append_width("loadv", vwidth) + "(i);") << std::endl;
|
||||||
//Update accumulators
|
//Update accumulators
|
||||||
for (symbolic::reduce_1d* rd : reductions)
|
for (symbolic::reduce_1d* rd : reductions)
|
||||||
for (unsigned int s = 0; s < vwidth; ++s)
|
for (uint32_t s = 0; s < vwidth; ++s)
|
||||||
{
|
{
|
||||||
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}});
|
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}});
|
||||||
if (is_indexing(rd->op().type))
|
if (is_indexing(rd->op().type))
|
||||||
@@ -203,14 +203,14 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
unroll_tmp();
|
unroll_tmp();
|
||||||
//Declarations
|
//Declarations
|
||||||
stream << "unsigned int lid = $LOCAL_IDX_0;" << std::endl;
|
stream << "uint32_t lid = $LOCAL_IDX_0;" << std::endl;
|
||||||
stream << "unsigned int lsize = $LOCAL_SIZE_0;" << std::endl;
|
stream << "uint32_t lsize = $LOCAL_SIZE_0;" << std::endl;
|
||||||
for (symbolic::reduce_1d* rd: reductions)
|
for (symbolic::reduce_1d* rd: reductions)
|
||||||
{
|
{
|
||||||
if (is_indexing(rd->op().type))
|
if (is_indexing(rd->op().type))
|
||||||
{
|
{
|
||||||
stream << rd->process("$LOCAL unsigned int #name_buf[" + tools::to_string(p_.ls0) + "];");
|
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];");
|
||||||
stream << rd->process("unsigned int #name_acc = 0;") << std::endl;
|
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
|
||||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||||
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
|
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
|
||||||
}
|
}
|
||||||
@@ -221,7 +221,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
//Private reduction
|
//Private reduction
|
||||||
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
|
stream << "for(uint32_t i = lid; i < " << p_.ng << "; i += lsize)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
for (symbolic::reduce_1d* rd: reductions)
|
for (symbolic::reduce_1d* rd: reductions)
|
||||||
@@ -256,7 +256,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
reduce_1d::reduce_1d(reduce_1d::parameters_type const & parameters) : base_impl<reduce_1d, reduce_1d_parameters>(parameters)
|
reduce_1d::reduce_1d(reduce_1d::parameters_type const & parameters) : base_impl<reduce_1d, reduce_1d_parameters>(parameters)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
reduce_1d::reduce_1d(unsigned int simd, unsigned int ls, unsigned int ng, fetch_type fetch):
|
reduce_1d::reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||||
base_impl<reduce_1d, reduce_1d_parameters>(reduce_1d_parameters(simd,ls,ng,fetch))
|
base_impl<reduce_1d, reduce_1d_parameters>(reduce_1d_parameters(simd,ls,ng,fetch))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
@@ -282,18 +282,18 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
|||||||
driver::Kernel kernels[2] = { driver::Kernel(program,name[0].c_str()), driver::Kernel(program,name[1].c_str()) };
|
driver::Kernel kernels[2] = { driver::Kernel(program,name[0].c_str()), driver::Kernel(program,name[1].c_str()) };
|
||||||
|
|
||||||
//NDRange
|
//NDRange
|
||||||
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.num_groups), driver::NDRange(p_.ls0) };
|
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng), driver::NDRange(p_.ls0) };
|
||||||
driver::NDRange local[2] = { driver::NDRange(p_.ls0), driver::NDRange(p_.ls0) };
|
driver::NDRange local[2] = { driver::NDRange(p_.ls0), driver::NDRange(p_.ls0) };
|
||||||
//Arguments
|
//Arguments
|
||||||
for (auto & kernel : kernels)
|
for (auto & kernel : kernels)
|
||||||
{
|
{
|
||||||
unsigned int n_arg = 0;
|
uint32_t n_arg = 0;
|
||||||
kernel.setSizeArg(n_arg++, size);
|
kernel.setSizeArg(n_arg++, size);
|
||||||
kernel.setArg(n_arg++, driver::backend::workspaces::get(queue));
|
kernel.setArg(n_arg++, driver::backend::workspaces::get(queue));
|
||||||
symbolic::set_arguments(x, kernel, n_arg);
|
symbolic::set_arguments(x, kernel, n_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned int k = 0; k < 2; k++)
|
for (uint32_t k = 0; k < 2; k++)
|
||||||
control.execution_options().enqueue(program.context(), kernels[k], global[k], local[k]);
|
control.execution_options().enqueue(program.context(), kernels[k], global[k], local[k]);
|
||||||
queue.synchronize();
|
queue.synchronize();
|
||||||
}
|
}
|
||||||
|
@@ -39,10 +39,10 @@ namespace isaac
|
|||||||
namespace templates
|
namespace templates
|
||||||
{
|
{
|
||||||
|
|
||||||
reduce_2d_parameters::reduce_2d_parameters(unsigned int _vwidth,
|
reduce_2d_parameters::reduce_2d_parameters(uint32_t _vwidth,
|
||||||
unsigned int _ls0, unsigned int _ls1,
|
uint32_t _ls0, uint32_t _ls1,
|
||||||
unsigned int _num_groups_0, unsigned int _num_groups_1, fetch_type _fetch_policy): base::parameters_type(_vwidth, _ls0, _ls1, 1),
|
uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy): base::parameters_type(_vwidth, _ls0, _ls1, 1),
|
||||||
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
|
ng0(_ng0), ng1(_ng1), fetch_policy(_fetch_policy) { }
|
||||||
|
|
||||||
|
|
||||||
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||||
@@ -52,17 +52,17 @@ int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &)
|
|||||||
return TEMPLATE_VALID;
|
return TEMPLATE_VALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int reduce_2d::lmem_usage(const expression_tree&) const
|
uint32_t reduce_2d::lmem_usage(const expression_tree&) const
|
||||||
{
|
{
|
||||||
return (p_.ls0+1)*p_.ls1;
|
return (p_.ls0+1)*p_.ls1;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int reduce_2d::temporary_workspace(expression_tree const & expressions) const
|
uint32_t reduce_2d::temporary_workspace(expression_tree const & expressions) const
|
||||||
{
|
{
|
||||||
std::vector<int_t> MN = input_sizes(expressions);
|
std::vector<int_t> MN = input_sizes(expressions);
|
||||||
int_t M = MN[0];
|
int_t M = MN[0];
|
||||||
if(p_.num_groups_0 > 1)
|
if(p_.ng0 > 1)
|
||||||
return M*p_.num_groups_0;
|
return M*p_.ng0;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,12 +80,12 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
name[0] += suffix;
|
name[0] += suffix;
|
||||||
name[1] += suffix;
|
name[1] += suffix;
|
||||||
|
|
||||||
unsigned int ldls = p_.ls0;
|
uint32_t ldls = p_.ls0;
|
||||||
std::string ls0ldstr = to_string(ldls);
|
std::string ls0ldstr = to_string(ldls);
|
||||||
|
|
||||||
auto unroll_tmp = [&]()
|
auto unroll_tmp = [&]()
|
||||||
{
|
{
|
||||||
unsigned int offset = 0;
|
uint32_t offset = 0;
|
||||||
for (symbolic::reduce_2d* rd : reductions)
|
for (symbolic::reduce_2d* rd : reductions)
|
||||||
{
|
{
|
||||||
numeric_type dtype = tree.dtype();
|
numeric_type dtype = tree.dtype();
|
||||||
@@ -93,13 +93,13 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
if (is_indexing(rd->op().type))
|
if (is_indexing(rd->op().type))
|
||||||
{
|
{
|
||||||
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint*)(tmp + " + tools::to_string(offset) + "*M);");
|
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||||
offset += 4*p_.num_groups_0;
|
offset += 4*p_.ng0;
|
||||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||||
offset += size_of(dtype)*p_.num_groups_0;
|
offset += size_of(dtype)*p_.ng0;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||||
offset += size_of(dtype)*p_.num_groups_0;
|
offset += size_of(dtype)*p_.ng0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -127,7 +127,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
std::ostringstream upper;
|
std::ostringstream upper;
|
||||||
upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1;
|
upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1;
|
||||||
|
|
||||||
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](unsigned int cwidth)
|
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
|
||||||
{
|
{
|
||||||
//Declare Buffers
|
//Declare Buffers
|
||||||
for (symbolic::reduce_2d* rd : reductions)
|
for (symbolic::reduce_2d* rd : reductions)
|
||||||
@@ -142,7 +142,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream << "if (r < M)" << std::endl;
|
stream << "if (r < M)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int rwidth)
|
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
|
||||||
{
|
{
|
||||||
std::string rdtype = append_width("#scalartype", rwidth);
|
std::string rdtype = append_width("#scalartype", rwidth);
|
||||||
std::string cdtype = append_width("#scalartype", cwidth);
|
std::string cdtype = append_width("#scalartype", cwidth);
|
||||||
@@ -158,7 +158,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
}
|
}
|
||||||
//Compute
|
//Compute
|
||||||
for (symbolic::reduce_2d* rd : reductions)
|
for (symbolic::reduce_2d* rd : reductions)
|
||||||
for (unsigned int s = 0; s < rwidth; ++s){
|
for (uint32_t s = 0; s < rwidth; ++s){
|
||||||
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, rwidth)}});
|
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, rwidth)}});
|
||||||
if (is_indexing(rd->op().type))
|
if (is_indexing(rd->op().type))
|
||||||
compute_index_reduce_1d(stream, rd->process("#name_acc"), "c*"+to_string(rwidth) + to_string(s), rd->process("#name_acc_value"), value, rd->op());
|
compute_index_reduce_1d(stream, rd->process("#name_acc"), "c*"+to_string(rwidth) + to_string(s), rd->process("#name_acc_value"), value, rd->op());
|
||||||
@@ -195,7 +195,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream << "if (r < M && lidx == 0)" << std::endl;
|
stream << "if (r < M && lidx == 0)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
if(p_.num_groups_0==1)
|
if(p_.ng0==1)
|
||||||
for(size_t idx: assignments)
|
for(size_t idx: assignments)
|
||||||
for(size_t s = 0 ; s < cwidth ; ++s)
|
for(size_t s = 0 ; s < cwidth ; ++s)
|
||||||
stream << symbols.at(idx)->evaluate({{"leaf", "at(r+" + to_string(s) + ")"},
|
stream << symbols.at(idx)->evaluate({{"leaf", "at(r+" + to_string(s) + ")"},
|
||||||
@@ -217,7 +217,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
/* ------------------------
|
/* ------------------------
|
||||||
* Kernel 2
|
* Kernel 2
|
||||||
* -----------------------*/
|
* -----------------------*/
|
||||||
if(p_.num_groups_0>1)
|
if(p_.ng0>1)
|
||||||
{
|
{
|
||||||
if(backend==driver::OPENCL)
|
if(backend==driver::OPENCL)
|
||||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
||||||
@@ -236,7 +236,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
|||||||
stream << "if (r < M)" << std::endl;
|
stream << "if (r < M)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << "for($SIZE_T c = lidx; c < " << p_.num_groups_0 << "; c += $LOCAL_SIZE_0){" << std::endl;
|
stream << "for($SIZE_T c = lidx; c < " << p_.ng0 << "; c += $LOCAL_SIZE_0){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
for (symbolic::reduce_2d* rd: reductions)
|
for (symbolic::reduce_2d* rd: reductions)
|
||||||
compute_reduce_1d(stream, rd->process("#name_acc"), rd->process("#name_temp[r + M*c]"), rd->op());
|
compute_reduce_1d(stream, rd->process("#name_acc"), rd->process("#name_temp[r + M*c]"), rd->op());
|
||||||
@@ -306,16 +306,16 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
|||||||
name[0] += suffix;
|
name[0] += suffix;
|
||||||
name[1] += suffix;
|
name[1] += suffix;
|
||||||
|
|
||||||
unsigned int nk = (p_.num_groups_0==1)?1:2;
|
uint32_t nk = (p_.ng0==1)?1:2;
|
||||||
|
|
||||||
std::vector<driver::Kernel> kernels;
|
std::vector<driver::Kernel> kernels;
|
||||||
for(unsigned int k = 0 ; k < nk ; ++k)
|
for(uint32_t k = 0 ; k < nk ; ++k)
|
||||||
kernels.push_back(driver::Kernel(program, name[k].c_str()));
|
kernels.push_back(driver::Kernel(program, name[k].c_str()));
|
||||||
|
|
||||||
for(unsigned int k = 0 ; k < nk ; ++k)
|
for(uint32_t k = 0 ; k < nk ; ++k)
|
||||||
{
|
{
|
||||||
driver::Kernel & kernel = kernels[k];
|
driver::Kernel & kernel = kernels[k];
|
||||||
unsigned int n_arg = 0;
|
uint32_t n_arg = 0;
|
||||||
int_t M = MN[0];
|
int_t M = MN[0];
|
||||||
int_t N = MN[1];
|
int_t N = MN[1];
|
||||||
kernel.setSizeArg(n_arg++, M);
|
kernel.setSizeArg(n_arg++, M);
|
||||||
@@ -325,20 +325,20 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
//NDRange
|
//NDRange
|
||||||
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.num_groups_0, p_.ls1*p_.num_groups_1), driver::NDRange(p_.ls0, p_.ls1*p_.num_groups_1) };
|
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng0, p_.ls1*p_.ng1), driver::NDRange(p_.ls0, p_.ls1*p_.ng1) };
|
||||||
driver::NDRange local[2] = { driver::NDRange(p_.ls0, p_.ls1), driver::NDRange(p_.ls0, p_.ls1) };
|
driver::NDRange local[2] = { driver::NDRange(p_.ls0, p_.ls1), driver::NDRange(p_.ls0, p_.ls1) };
|
||||||
for(unsigned int i = 0 ; i < nk ; ++i)
|
for(uint32_t i = 0 ; i < nk ; ++i)
|
||||||
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
reduce_2d_rows::reduce_2d_rows(reduce_2d_parameters const & parameters): reduce_2d(parameters, REDUCE_ROWS){}
|
reduce_2d_rows::reduce_2d_rows(reduce_2d_parameters const & parameters): reduce_2d(parameters, REDUCE_ROWS){}
|
||||||
|
|
||||||
reduce_2d_rows::reduce_2d_rows(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2,
|
reduce_2d_rows::reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2,
|
||||||
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS) {}
|
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS) {}
|
||||||
|
|
||||||
reduce_2d_cols::reduce_2d_cols(reduce_2d::parameters_type const & parameters): reduce_2d(parameters, REDUCE_COLUMNS){}
|
reduce_2d_cols::reduce_2d_cols(reduce_2d::parameters_type const & parameters): reduce_2d(parameters, REDUCE_COLUMNS){}
|
||||||
|
|
||||||
reduce_2d_cols::reduce_2d_cols(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2,
|
reduce_2d_cols::reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2,
|
||||||
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS) {}
|
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS) {}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -96,7 +96,7 @@ std::string hash(expression_tree const & tree)
|
|||||||
}
|
}
|
||||||
|
|
||||||
//Set arguments
|
//Set arguments
|
||||||
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, unsigned int & current_arg)
|
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, uint32_t& current_arg)
|
||||||
{
|
{
|
||||||
driver::backend_type backend = tree.context().backend();
|
driver::backend_type backend = tree.context().backend();
|
||||||
|
|
||||||
|
@@ -77,7 +77,7 @@ profiles::value_type::value_type(expression_type etype, numeric_type dtype, pred
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
profiles::value_type::value_type(expression_type etype, numeric_type dtype, templates::base const & tp, driver::CommandQueue const & queue) : templates_(1,tp.clone()), queue_(queue), cache_(driver::backend::programs::get(queue,etype,dtype))
|
profiles::value_type::value_type(expression_type etype, numeric_type dtype, std::shared_ptr<templates::base> const & tp, driver::CommandQueue const & queue) : templates_(1,tp), queue_(queue), cache_(driver::backend::programs::get(queue,etype,dtype))
|
||||||
{
|
{
|
||||||
cache_.clear();
|
cache_.clear();
|
||||||
}
|
}
|
||||||
@@ -197,7 +197,7 @@ void profiles::import(std::string const & str, driver::CommandQueue const & queu
|
|||||||
result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, predictor, templates, queue);
|
result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, predictor, templates, queue);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, *templates[0], queue);
|
result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, templates[0], queue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -73,7 +73,7 @@ def main():
|
|||||||
libraries += ['gnustl_shared']
|
libraries += ['gnustl_shared']
|
||||||
|
|
||||||
#Source files
|
#Source files
|
||||||
src = 'src/lib/exception/api.cpp src/lib/exception/driver.cpp src/lib/value_scalar.cpp src/lib/random/rand.cpp src/lib/driver/check.cpp src/lib/driver/ndrange.cpp src/lib/driver/platform.cpp src/lib/driver/backend.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/event.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/buffer.cpp src/lib/driver/context.cpp src/lib/driver/dispatch.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/base.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/syntax/engine/object.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/api/blas/clBLAS.cpp src/lib/api/blas/cublas.cpp src/lib/runtime/execute.cpp src/lib/runtime/predictors/random_forest.cpp src/lib/runtime/profiles.cpp src/lib/runtime/database.cpp src/lib/array.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
|
src = 'src/lib/random/rand.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/object.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/generation/base.cpp src/lib/runtime/execute.cpp src/lib/runtime/database.cpp src/lib/runtime/profiles.cpp src/lib/runtime/predictors/random_forest.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/dispatch.cpp src/lib/driver/program_cache.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/api/blas/clBLAS.cpp src/lib/api/blas/cublas.cpp src/lib/exception/api.cpp src/lib/exception/driver.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
|
||||||
boostsrc = 'external/boost/libs/'
|
boostsrc = 'external/boost/libs/'
|
||||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||||
|
Reference in New Issue
Block a user