Code Quality: Reverted uint32_t to unsigned int

Using uint32_t caused problems with Boost.Python on some platforms but not on others; there was no time to figure out why.
This commit is contained in:
Philippe Tillet
2016-10-03 02:53:47 -04:00
parent 31849794e8
commit fca79c317e
16 changed files with 155 additions and 155 deletions

View File

@@ -78,9 +78,9 @@ private:
public: public:
base(); base();
virtual ~base(); virtual ~base();
virtual uint32_t temporary_workspace(expression_tree const &) const; virtual unsigned int temporary_workspace(expression_tree const &) const;
virtual uint32_t lmem_usage(expression_tree const &) const; virtual unsigned int lmem_usage(expression_tree const &) const;
virtual uint32_t registers_usage(expression_tree const &) const; virtual unsigned int registers_usage(expression_tree const &) const;
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0; virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0; virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0; virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
@@ -96,15 +96,15 @@ class base_impl : public base
private: private:
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const; virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
public: public:
base_impl(uint32_t _vwidth, int_t _ls0, int_t _ls1); base_impl(unsigned int _vwidth, int_t _ls0, int_t _ls1);
uint32_t ls0() const; unsigned int ls0() const;
uint32_t ls1() const; unsigned int ls1() const;
/** @brief returns whether or not the profile has undefined behavior on particular device */ /** @brief returns whether or not the profile has undefined behavior on particular device */
int is_invalid(expression_tree const & expressions, driver::Device const & device) const; int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
protected: protected:
uint32_t vwidth_; unsigned int vwidth_;
uint32_t ls0_; unsigned int ls0_;
uint32_t ls1_; unsigned int ls1_;
}; };
} }

View File

@@ -35,11 +35,11 @@ private:
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const; virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const;
public: public:
elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch); elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
private: private:
uint32_t ng_; unsigned int ng_;
fetch_type fetch_; fetch_type fetch_;
}; };

View File

@@ -36,12 +36,12 @@ private:
int is_invalid_impl(driver::Device const &, expression_tree const &) const; int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
public: public:
elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch); elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
private: private:
uint32_t ng0_; unsigned int ng0_;
uint32_t ng1_; unsigned int ng1_;
fetch_type fetch_; fetch_type fetch_;
}; };

View File

@@ -33,12 +33,12 @@ class kernel_generation_stream : public std::ostream
class kgenstream : public std::stringbuf class kgenstream : public std::stringbuf
{ {
public: public:
kgenstream(std::ostringstream& oss,uint32_t const & tab_count) ; kgenstream(std::ostringstream& oss,unsigned int const & tab_count) ;
int sync(); int sync();
~kgenstream(); ~kgenstream();
private: private:
std::ostream& oss_; std::ostream& oss_;
uint32_t const & tab_count_; unsigned int const & tab_count_;
}; };
void process(std::string& str); void process(std::string& str);
@@ -51,7 +51,7 @@ public:
void inc_tab(); void inc_tab();
void dec_tab(); void dec_tab();
private: private:
uint32_t tab_count_; unsigned int tab_count_;
driver::backend_type backend_; driver::backend_type backend_;
std::ostringstream oss; std::ostringstream oss;
}; };

View File

@@ -34,37 +34,37 @@ namespace templates
class gemm : public base_impl class gemm : public base_impl
{ {
private: private:
uint32_t temporary_workspace(expression_tree const & expressions) const; unsigned int temporary_workspace(expression_tree const & expressions) const;
uint32_t lmem_usage(expression_tree const & expressions) const; unsigned int lmem_usage(expression_tree const & expressions) const;
uint32_t registers_usage(expression_tree const & expressions) const; unsigned int registers_usage(expression_tree const & expressions) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const; int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const &) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const &) const;
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, const expression_tree::node &A, const expression_tree::node &B, const expression_tree::node &C, void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, const expression_tree::node &A, const expression_tree::node &B, const expression_tree::node &C,
value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, runtime::execution_options_type const & options); value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, runtime::execution_options_type const & options);
std::vector<int_t> infos(expression_tree const & expressions, isaac::symbolic::preset::gemm::args &arguments) const; std::vector<int_t> infos(expression_tree const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
public: public:
gemm(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1, char A_trans, char B_trans); , int_t lf0, int_t lf1, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr);
private: private:
//Parameters //Parameters
uint32_t kL_; unsigned int kL_;
uint32_t depth_; unsigned int depth_;
uint32_t mS_; unsigned int mS_;
uint32_t kS_; unsigned int kS_;
uint32_t nS_; unsigned int nS_;
fetch_type Afetch_; fetch_type Afetch_;
fetch_type Bfetch_; fetch_type Bfetch_;
uint32_t lf0_; unsigned int lf0_;
uint32_t lf1_; unsigned int lf1_;
uint32_t mL_; unsigned int mL_;
uint32_t nL_; unsigned int nL_;
bool prefetch_; bool prefetch_;
bool unroll_outer_; bool unroll_outer_;
@@ -77,7 +77,7 @@ private:
class gemm_nn : public gemm class gemm_nn : public gemm
{ {
public: public:
gemm_nn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1); , int_t lf0, int_t lf1);
}; };
@@ -85,7 +85,7 @@ public:
class gemm_tn : public gemm class gemm_tn : public gemm
{ {
public: public:
gemm_tn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1); , int_t lf0, int_t lf1);
}; };
@@ -94,7 +94,7 @@ public:
class gemm_nt : public gemm class gemm_nt : public gemm
{ {
public: public:
gemm_nt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1); , int_t lf0, int_t lf1);
}; };
@@ -103,7 +103,7 @@ public:
class gemm_tt : public gemm class gemm_tt : public gemm
{ {
public: public:
gemm_tt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1); , int_t lf0, int_t lf1);
}; };

View File

@@ -32,19 +32,19 @@ namespace templates
class reduce_1d : public base_impl class reduce_1d : public base_impl
{ {
private: private:
uint32_t lmem_usage(expression_tree const & expressions) const; unsigned int lmem_usage(expression_tree const & expressions) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const; int is_invalid_impl(driver::Device const &, expression_tree const &) const;
uint32_t temporary_workspace(expression_tree const & expressions) const; unsigned int temporary_workspace(expression_tree const & expressions) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, uint32_t size, std::vector<symbolic::reduce_1d*> exprs, inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const; std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
public: public:
reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch); reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
private: private:
uint32_t ng_; unsigned int ng_;
fetch_type fetch_; fetch_type fetch_;
std::vector< driver::Buffer > tmp_; std::vector< driver::Buffer > tmp_;
std::vector< driver::Buffer > tmpidx_; std::vector< driver::Buffer > tmpidx_;

View File

@@ -35,18 +35,18 @@ namespace templates
class reduce_2d : public base_impl class reduce_2d : public base_impl
{ {
protected: protected:
reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, operation_type_family); reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch, operation_type_family);
private: private:
int is_invalid_impl(driver::Device const &, expression_tree const &) const; int is_invalid_impl(driver::Device const &, expression_tree const &) const;
uint32_t lmem_usage(expression_tree const &) const; unsigned int lmem_usage(expression_tree const &) const;
uint32_t temporary_workspace(expression_tree const & expressions) const; unsigned int temporary_workspace(expression_tree const & expressions) const;
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const; std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
public: public:
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const; virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
private: private:
uint32_t ng0_; unsigned int ng0_;
uint32_t ng1_; unsigned int ng1_;
fetch_type fetch_; fetch_type fetch_;
operation_type_family reduction_type_; operation_type_family reduction_type_;
}; };
@@ -54,13 +54,13 @@ private:
class reduce_2d_rows : public reduce_2d class reduce_2d_rows : public reduce_2d
{ {
public: public:
reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch); reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch);
}; };
class reduce_2d_cols : public reduce_2d class reduce_2d_cols : public reduce_2d
{ {
public: public:
reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch); reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch);
}; };
} }

View File

@@ -90,7 +90,7 @@ std::vector<size_t> rhs_of(expression_tree const & tree, std::vector<size_t> con
std::string hash(expression_tree const & tree); std::string hash(expression_tree const & tree);
//Set arguments //Set arguments
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, uint32_t& current_arg); void set_arguments(expression_tree const & tree, driver::Kernel & kernel, unsigned int& current_arg);
//Symbolize //Symbolize
symbols_table symbolize(isaac::expression_tree const & expression); symbols_table symbolize(isaac::expression_tree const & expression);

View File

@@ -43,13 +43,13 @@ namespace templates
base::base() base::base()
{} {}
uint32_t base::lmem_usage(expression_tree const &) const unsigned int base::lmem_usage(expression_tree const &) const
{ return 0; } { return 0; }
uint32_t base::registers_usage(expression_tree const &) const unsigned int base::registers_usage(expression_tree const &) const
{ return 0; } { return 0; }
uint32_t base::temporary_workspace(expression_tree const &) const unsigned int base::temporary_workspace(expression_tree const &) const
{ return 0; } { return 0; }
base::~base() base::~base()
@@ -69,13 +69,13 @@ std::string base::generate(std::string const & suffix, expression_tree const &
int base_impl::is_invalid_impl(driver::Device const &, expression_tree const &) const int base_impl::is_invalid_impl(driver::Device const &, expression_tree const &) const
{ return TEMPLATE_VALID; } { return TEMPLATE_VALID; }
base_impl::base_impl(uint32_t vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls0_(ls0), ls1_(ls1) base_impl::base_impl(unsigned int vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls0_(ls0), ls1_(ls1)
{ } { }
uint32_t base_impl::ls0() const unsigned int base_impl::ls0() const
{ return ls0_; } { return ls0_; }
uint32_t base_impl::ls1() const unsigned int base_impl::ls1() const
{ return ls1_; } { return ls1_; }
int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const

View File

@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
stream.inc_tab(); stream.inc_tab();
} }
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth) element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth)
{ {
std::string dtype = append_width("#scalartype",vwidth); std::string dtype = append_width("#scalartype",vwidth);
@@ -89,12 +89,12 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
//Compute //Compute
for(size_t idx: assignments) for(size_t idx: assignments)
for(uint32_t s = 0 ; s < vwidth ; ++s) for(unsigned int s = 0 ; s < vwidth ; ++s)
stream << symbols.at(idx)->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}}) << ";" << std::endl; stream << symbols.at(idx)->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}}) << ";" << std::endl;
//Writes back //Writes back
for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assignments_lhs, false)) for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assignments_lhs, false))
for(uint32_t s = 0 ; s < vwidth ; ++s) for(unsigned int s = 0 ; s < vwidth ; ++s)
stream << sym->process("at(i+" + tools::to_string(s)+") = " + access_vector_type("#name", s, vwidth) + ";") << std::endl; stream << sym->process("at(i+" + tools::to_string(s)+") = " + access_vector_type("#name", s, vwidth) + ";") << std::endl;
}); });
//Close user-provided for-loops //Close user-provided for-loops
@@ -110,7 +110,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
return stream.str(); return stream.str();
} }
elementwise_1d::elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch): elementwise_1d::elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch):
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch) base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
{} {}
@@ -133,7 +133,7 @@ void elementwise_1d::enqueue(driver::CommandQueue &, driver::Program const & pro
driver::NDRange global(ls0_*ng_); driver::NDRange global(ls0_*ng_);
driver::NDRange local(ls0_); driver::NDRange local(ls0_);
//Arguments //Arguments
uint32_t current_arg = 0; unsigned int current_arg = 0;
kernel.setSizeArg(current_arg++, size); kernel.setSizeArg(current_arg++, size);
symbolic::set_arguments(expressions, kernel, current_arg); symbolic::set_arguments(expressions, kernel, current_arg);
control.execution_options().enqueue(program.context(), kernel, global, local); control.execution_options().enqueue(program.context(), kernel, global, local);

View File

@@ -104,8 +104,8 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
return stream.str(); return stream.str();
} }
elementwise_2d::elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, elementwise_2d::elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1,
uint32_t ng0, uint32_t ng1, fetch_type fetch): unsigned int ng0, unsigned int ng1, fetch_type fetch):
base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch) base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch)
{} {}
@@ -121,7 +121,7 @@ void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program c
driver::Kernel kernel(program, name.c_str()); driver::Kernel kernel(program, name.c_str());
driver::NDRange global(ls0_*ng0_, ls1_*ng1_); driver::NDRange global(ls0_*ng0_, ls1_*ng1_);
driver::NDRange local(ls0_, ls1_); driver::NDRange local(ls0_, ls1_);
uint32_t current_arg = 0; unsigned int current_arg = 0;
std::vector<int_t> MN = input_sizes(expressions); std::vector<int_t> MN = input_sizes(expressions);
kernel.setSizeArg(current_arg++, MN[0]); kernel.setSizeArg(current_arg++, MN[0]);
kernel.setSizeArg(current_arg++, MN[1]); kernel.setSizeArg(current_arg++, MN[1]);

View File

@@ -25,14 +25,14 @@
namespace isaac namespace isaac
{ {
kernel_generation_stream::kgenstream::kgenstream(std::ostringstream& oss,uint32_t const & tab_count) : kernel_generation_stream::kgenstream::kgenstream(std::ostringstream& oss,unsigned int const & tab_count) :
oss_(oss), tab_count_(tab_count) oss_(oss), tab_count_(tab_count)
{ } { }
int kernel_generation_stream::kgenstream::sync() int kernel_generation_stream::kgenstream::sync()
{ {
for (uint32_t i=0; i<tab_count_;++i) for (unsigned int i=0; i<tab_count_;++i)
oss_ << " "; oss_ << " ";
std::string next = str(); std::string next = str();
oss_ << next; oss_ << next;

View File

@@ -37,9 +37,9 @@ namespace isaac
namespace templates namespace templates
{ {
uint32_t gemm::lmem_usage(expression_tree const & expression) const unsigned int gemm::lmem_usage(expression_tree const & expression) const
{ {
uint32_t N = 0; unsigned int N = 0;
size_t llda = (A_trans_=='N')?mL_:kL_+1; size_t llda = (A_trans_=='N')?mL_:kL_+1;
size_t lnda = (A_trans_=='N')?kL_:mL_; size_t lnda = (A_trans_=='N')?kL_:mL_;
size_t lldb = (B_trans_=='T')?nL_:kL_+1; size_t lldb = (B_trans_=='T')?nL_:kL_+1;
@@ -49,13 +49,13 @@ namespace templates
return N*size_of(expression.dtype()); return N*size_of(expression.dtype());
} }
uint32_t gemm::registers_usage(expression_tree const & expression) const unsigned int gemm::registers_usage(expression_tree const & expression) const
{ {
uint32_t N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_; unsigned int N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;
return N*size_of(expression.dtype()); return N*size_of(expression.dtype());
} }
uint32_t gemm::temporary_workspace(expression_tree const & expressions) const unsigned int gemm::temporary_workspace(expression_tree const & expressions) const
{ {
std::vector<int_t> MNK = input_sizes(expressions); std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1]; int_t M = MNK[0]; int_t N = MNK[1];
@@ -85,8 +85,8 @@ namespace templates
if (Afetch_==FETCH_FROM_LOCAL) if (Afetch_==FETCH_FROM_LOCAL)
{ {
uint32_t bound1 = (A_trans_=='N')?kL_:mL_; unsigned int bound1 = (A_trans_=='N')?kL_:mL_;
uint32_t bound0 = (A_trans_=='N')?mL_:kL_; unsigned int bound0 = (A_trans_=='N')?mL_:kL_;
if (lf1_>0 && (bound1 % lf1_)> 0) if (lf1_>0 && (bound1 % lf1_)> 0)
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE; return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
@@ -97,8 +97,8 @@ namespace templates
} }
if (Bfetch_==FETCH_FROM_LOCAL) if (Bfetch_==FETCH_FROM_LOCAL)
{ {
uint32_t bound1 = (B_trans_=='T')?kL_:nL_; unsigned int bound1 = (B_trans_=='T')?kL_:nL_;
uint32_t bound0 = (B_trans_=='T')?nL_:kL_; unsigned int bound0 = (B_trans_=='T')?nL_:kL_;
if (lf1_>0 && (bound1 % lf1_)> 0) if (lf1_>0 && (bound1 % lf1_)> 0)
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE; return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
@@ -178,8 +178,8 @@ namespace templates
size_t lndb = (B_trans_=='T')?kL_:nL_; size_t lndb = (B_trans_=='T')?kL_:nL_;
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl; stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl; stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
uint32_t npA = mL_/(A_trans_=='N'?lf0_*vwidth_:lf1_); unsigned int npA = mL_/(A_trans_=='N'?lf0_*vwidth_:lf1_);
uint32_t npB = nL_/(B_trans_=='T'?lf0_*vwidth_:lf1_); unsigned int npB = nL_/(B_trans_=='T'?lf0_*vwidth_:lf1_);
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl; stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl; stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << std::endl; stream << std::endl;
@@ -278,13 +278,13 @@ namespace templates
stream << "}" << std::endl; stream << "}" << std::endl;
stream << std::endl; stream << std::endl;
for(uint32_t i = 0 ; i < npA ; i++ ) for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N') if (A_trans_=='N')
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < M", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl; stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < M", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
else else
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < M", "(int)((idT.y + " + to_string(i*lf1_) + ")*lda)", "0") << ";" << std::endl; stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < M", "(int)((idT.y + " + to_string(i*lf1_) + ")*lda)", "0") << ";" << std::endl;
for(uint32_t i = 0 ; i < npB ; i++ ) for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T') if (B_trans_=='T')
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < N", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl; stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < N", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
else else
@@ -306,13 +306,13 @@ namespace templates
stream << "//Fetch A to local memory" << std::endl; stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N') if (A_trans_=='N')
{ {
for(uint32_t k = 0; k < kL_; k += lf1_) for(unsigned int k = 0; k < kL_; k += lf1_)
for(uint32_t m = 0; m < mL_; m += lf0_*vwidth_) for(unsigned int m = 0; m < mL_; m += lf0_*vwidth_)
{ {
std::string mm = to_string(m/(vwidth_*lf0_)); std::string mm = to_string(m/(vwidth_*lf0_));
std::string kk = to_string(k); std::string kk = to_string(k);
if(last_iteration) if(last_iteration)
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl; stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl; stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
@@ -320,13 +320,13 @@ namespace templates
} }
else else
{ {
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_) for(unsigned int k = 0; k < kL_; k += lf0_*vwidth_)
for(uint32_t m = 0; m < mL_; m += lf1_) for(unsigned int m = 0; m < mL_; m += lf1_)
{ {
std::string mm = to_string(m/lf1_); std::string mm = to_string(m/lf1_);
std::string kk = to_string(k); std::string kk = to_string(k);
if(last_iteration) if(last_iteration)
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl; stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
else else
@@ -337,13 +337,13 @@ namespace templates
stream << "//Fetch B to local memory" << std::endl; stream << "//Fetch B to local memory" << std::endl;
if (B_trans_=='T') if (B_trans_=='T')
{ {
for(uint32_t k = 0; k < kL_; k += lf1_) for(unsigned int k = 0; k < kL_; k += lf1_)
for(uint32_t n = 0; n < nL_; n += lf0_*vwidth_) for(unsigned int n = 0; n < nL_; n += lf0_*vwidth_)
{ {
std::string nn = to_string(n/(vwidth_*lf0_)); std::string nn = to_string(n/(vwidth_*lf0_));
std::string kk = to_string(k); std::string kk = to_string(k);
if(last_iteration) if(last_iteration)
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl; stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl; stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
@@ -351,13 +351,13 @@ namespace templates
} }
else else
{ {
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_) for(unsigned int k = 0; k < kL_; k += lf0_*vwidth_)
for(uint32_t n = 0; n < nL_; n += lf1_) for(unsigned int n = 0; n < nL_; n += lf1_)
{ {
std::string nn = to_string(n/lf1_); std::string nn = to_string(n/lf1_);
std::string kk = to_string(k); std::string kk = to_string(k);
if(last_iteration) if(last_iteration)
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl; stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
else else
@@ -379,14 +379,14 @@ namespace templates
std::string bound = last_iteration?"K":tools::to_string(kL_); std::string bound = last_iteration?"K":tools::to_string(kL_);
size_t ks = last_iteration?1:kS_; size_t ks = last_iteration?1:kS_;
stream << "//Inner loop" << std::endl; stream << "//Inner loop" << std::endl;
stream << "for(uint32_t k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl; stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "//Fetch A to registers" << std::endl; stream << "//Fetch A to registers" << std::endl;
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl; stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << mS_/vwidth_ << std::endl; stream << "#pragma unroll " << mS_/vwidth_ << std::endl;
stream << "for(uint32_t mm = 0; mm < " << mS_/vwidth_ << "; mm++)" << std::endl; stream << "for(unsigned int mm = 0; mm < " << mS_/vwidth_ << "; mm++)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
if(A_trans_=='N') if(A_trans_=='N')
@@ -396,7 +396,7 @@ namespace templates
if(vwidth_==1) if(vwidth_==1)
stream << "rA[kk][mm] = ldsA[k + mm*" << ls0_*llda << "+ kk" << "];" << std::endl; stream << "rA[kk][mm] = ldsA[k + mm*" << ls0_*llda << "+ kk" << "];" << std::endl;
else else
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << vwidth_*ls0_ << " + " << s << ")*" << llda << "+ kk];" << std::endl; stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << vwidth_*ls0_ << " + " << s << ")*" << llda << "+ kk];" << std::endl;
} }
@@ -405,9 +405,9 @@ namespace templates
stream << "//Fetch B to registers" << std::endl; stream << "//Fetch B to registers" << std::endl;
stream << "#pragma unroll " << ks << std::endl; stream << "#pragma unroll " << ks << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl; stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << nS_/vwidth_ << std::endl; stream << "#pragma unroll " << nS_/vwidth_ << std::endl;
stream << "for(uint32_t nn = 0; nn < " << nS_/vwidth_ << "; nn++)" << std::endl; stream << "for(unsigned int nn = 0; nn < " << nS_/vwidth_ << "; nn++)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
if(B_trans_=='T') if(B_trans_=='T')
@@ -417,7 +417,7 @@ namespace templates
if(vwidth_==1) if(vwidth_==1)
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << ls1_*lldb << "+ kk" << "];" << std::endl; stream << "rB[kk][nn] = ldsB[k" << " + nn*" << ls1_*lldb << "+ kk" << "];" << std::endl;
else else
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << vwidth_*ls1_ << " + " << s << ")*" << lldb << "+ kk];" << std::endl; stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << vwidth_*ls1_ << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
} }
stream.dec_tab(); stream.dec_tab();
@@ -425,10 +425,10 @@ namespace templates
stream << "//FMA computations" << std::endl; stream << "//FMA computations" << std::endl;
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
stream << "for(uint32_t kk = 0 ; kk < " << ks << "; ++kk){" << std::endl; stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
stream.inc_tab(); stream.inc_tab();
for(uint32_t nn=0; nn < nS_; ++nn) for(unsigned int nn=0; nn < nS_; ++nn)
for(uint32_t mm=0; mm < mS_; ++mm){ for(unsigned int mm=0; mm < mS_; ++mm){
string res_str, lhs_str, rhs_str; string res_str, lhs_str, rhs_str;
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]"; res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
if (vwidth_==1) if (vwidth_==1)
@@ -449,18 +449,18 @@ namespace templates
//Increment A pointers to global memory //Increment A pointers to global memory
if (A_trans_=='N') if (A_trans_=='N')
for(uint32_t i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << kL_ << "*lda;" << std::endl; stream << "Ai[" << i << "] += " << kL_ << "*lda;" << std::endl;
else else
for(uint32_t i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << kL_ << ASTRIDE1 << ";" << std::endl; stream << "Ai[" << i << "] += " << kL_ << ASTRIDE1 << ";" << std::endl;
//Increment B pointers to global memory //Increment B pointers to global memory
if (B_trans_=='T') if (B_trans_=='T')
for(uint32_t i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << kL_ << "*ldb;" << std::endl; stream << "Bi[" << i << "] += " << kL_ << "*ldb;" << std::endl;
else else
for(uint32_t i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << kL_ << BSTRIDE1 << ";" << std::endl; stream << "Bi[" << i << "] += " << kL_ << BSTRIDE1 << ";" << std::endl;
}; };
fetch_to_lds(false); fetch_to_lds(false);
@@ -471,15 +471,15 @@ namespace templates
if(A_trans_=='N' || B_trans_=='T') if(A_trans_=='N' || B_trans_=='T')
{ {
stream << "int Ky = K - idT.y;" << std::endl; stream << "int Ky = K - idT.y;" << std::endl;
for(uint32_t k = 0; k < kL_; k += lf1_) for(unsigned int k = 0; k < kL_; k += lf1_)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl; stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
} }
if(A_trans_=='T' || B_trans_=='N') if(A_trans_=='T' || B_trans_=='N')
{ {
stream << "int Kx = K - idT.x;" << std::endl; stream << "int Kx = K - idT.x;" << std::endl;
for(uint32_t k = 0 ; k < kL_ ; k += lf0_*vwidth_) for(unsigned int k = 0 ; k < kL_ ; k += lf0_*vwidth_)
for(uint32_t s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl; stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
} }
fetch_to_lds(true); fetch_to_lds(true);
@@ -510,13 +510,13 @@ namespace templates
stream << "N -= ids.y;" << std::endl; stream << "N -= ids.y;" << std::endl;
stream << "N -= ids.w*" << vwidth_ << ";" << std::endl; stream << "N -= ids.w*" << vwidth_ << ";" << std::endl;
for(uint32_t n=0; n < nS_; ++n) for(unsigned int n=0; n < nS_; ++n)
{ {
string Cj = to_string((n/vwidth_)*(ls1_*vwidth_) + n%vwidth_); string Cj = to_string((n/vwidth_)*(ls1_*vwidth_) + n%vwidth_);
stream << "if(" << Cj << " >= N) return;" << std::endl; stream << "if(" << Cj << " >= N) return;" << std::endl;
for(uint32_t m=0; m < mS_; ++m) for(unsigned int m=0; m < mS_; ++m)
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl; stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
for(uint32_t m=0; m < mS_; ++m) for(unsigned int m=0; m < mS_; ++m)
{ {
string Ci = to_string((m/vwidth_)*(ls0_*vwidth_) + m%vwidth_); string Ci = to_string((m/vwidth_)*(ls0_*vwidth_) + m%vwidth_);
stream << "if(" << Ci << "< M) "; stream << "if(" << Ci << "< M) ";
@@ -548,14 +548,14 @@ namespace templates
stream.inc_tab(); stream.inc_tab();
stream << "C += Cstart;" << std::endl; stream << "C += Cstart;" << std::endl;
stream << "for(uint32_t i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl; stream << "for(unsigned int i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "for(uint32_t j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl; stream << "for(unsigned int j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << sdtype << " acc = 0;" << std::endl; stream << sdtype << " acc = 0;" << std::endl;
stream << "for(uint32_t k = 0 ; k < D ; k++)" << std::endl; stream << "for(unsigned int k = 0 ; k < D ; k++)" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl; stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
stream.dec_tab(); stream.dec_tab();
@@ -597,7 +597,7 @@ namespace templates
driver::NDRange local(ls0_, ls1_, 1); driver::NDRange local(ls0_, ls1_, 1);
driver::NDRange global(align(align(M,mS_)/mS_, ls0_), align(align(N,nS_)/nS_, ls1_), depth_); driver::NDRange global(align(align(M,mS_)/mS_, ls0_), align(align(N,nS_)/nS_, ls1_), depth_);
uint32_t current_arg = 0; unsigned int current_arg = 0;
driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context())); driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context()));
gemm.setSizeArg(current_arg++, M); gemm.setSizeArg(current_arg++, M);
@@ -644,7 +644,7 @@ namespace templates
if(depth_ > 1) if(depth_ > 1)
{ {
uint32_t current_arg = 0; unsigned int current_arg = 0;
driver::Kernel reduce(program, reduce_name.c_str()); driver::Kernel reduce(program, reduce_name.c_str());
driver::NDRange local(ls0_, ls1_); driver::NDRange local(ls0_, ls1_);
driver::NDRange global(align(M, ls0_), align(N, ls1_)); driver::NDRange global(align(M, ls0_), align(N, ls1_));
@@ -677,7 +677,7 @@ namespace templates
return {M, N, K}; return {M, N, K};
} }
gemm::gemm(uint32_t vwidth gemm::gemm(unsigned int vwidth
,int_t ls0, int_t kL, int_t ls1, int_t D ,int_t ls0, int_t kL, int_t ls1, int_t D
,int_t ms, int_t ks, int_t ns ,int_t ms, int_t ks, int_t ns
,fetch_type Afetch , fetch_type Bfetch ,fetch_type Afetch , fetch_type Bfetch
@@ -715,7 +715,7 @@ namespace templates
} }
// //
gemm_nn::gemm_nn(uint32_t vwidth gemm_nn::gemm_nn(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch , fetch_type Afetch , fetch_type Bfetch
@@ -725,7 +725,7 @@ namespace templates
} }
// //
gemm_tn::gemm_tn(uint32_t vwidth gemm_tn::gemm_tn(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch , fetch_type Afetch , fetch_type Bfetch
@@ -734,7 +734,7 @@ namespace templates
{ } { }
// //
gemm_nt::gemm_nt(uint32_t vwidth gemm_nt::gemm_nt(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch , fetch_type Afetch , fetch_type Bfetch
@@ -743,7 +743,7 @@ namespace templates
{ } { }
// //
gemm_tt::gemm_tt(uint32_t vwidth gemm_tt::gemm_tt(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch , fetch_type Afetch , fetch_type Bfetch

View File

@@ -36,7 +36,7 @@ namespace isaac
namespace templates namespace templates
{ {
uint32_t reduce_1d::lmem_usage(expression_tree const & x) const unsigned int reduce_1d::lmem_usage(expression_tree const & x) const
{ {
return ls0_*size_of(x.dtype()); return ls0_*size_of(x.dtype());
} }
@@ -48,18 +48,18 @@ int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &)
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
uint32_t reduce_1d::temporary_workspace(expression_tree const &) const unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
{ {
if(ng_ > 1) if(ng_ > 1)
return ng_; return ng_;
return 0; return 0;
} }
inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, uint32_t size, std::vector<symbolic::reduce_1d*> exprs, inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type) const std::string const & buf_str, std::string const & buf_value_str, driver::backend_type) const
{ {
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
stream << "for(uint32_t stride = " << size/2 << "; stride > 0; stride /=2)" << std::endl; stream << "for(unsigned int stride = " << size/2 << "; stride > 0; stride /=2)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "$LOCAL_BARRIER;" << std::endl; stream << "$LOCAL_BARRIER;" << std::endl;
@@ -91,7 +91,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
auto unroll_tmp = [&]() auto unroll_tmp = [&]()
{ {
uint32_t offset = 0; unsigned int offset = 0;
for(symbolic::reduce_1d* rd: reductions) for(symbolic::reduce_1d* rd: reductions)
{ {
numeric_type dtype = tree.dtype(); numeric_type dtype = tree.dtype();
@@ -126,10 +126,10 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream.inc_tab(); stream.inc_tab();
unroll_tmp(); unroll_tmp();
//Declare //Declare
stream << "uint32_t lid = $LOCAL_IDX_0;" << std::endl; stream << "unsigned int lid = $LOCAL_IDX_0;" << std::endl;
stream << "uint32_t gid = $GLOBAL_IDX_0;" << std::endl; stream << "unsigned int gid = $GLOBAL_IDX_0;" << std::endl;
stream << "uint32_t gpid = $GROUP_IDX_0;" << std::endl; stream << "unsigned int gpid = $GROUP_IDX_0;" << std::endl;
stream << "uint32_t gsize = $GLOBAL_SIZE_0;" << std::endl; stream << "unsigned int gsize = $GLOBAL_SIZE_0;" << std::endl;
for(symbolic::reduce_1d* rd: reductions) for(symbolic::reduce_1d* rd: reductions)
{ {
@@ -137,8 +137,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
{ {
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl; stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl; stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];") << std::endl; stream << rd->process("$LOCAL unsigned int #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("uint32_t #name_acc = 0;") << std::endl; stream << rd->process("unsigned int #name_acc = 0;") << std::endl;
} }
else else
{ {
@@ -146,7 +146,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl; stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
} }
} }
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth) element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth)
{ {
std::string dtype = append_width("#scalartype",vwidth); std::string dtype = append_width("#scalartype",vwidth);
//Fetch vector entry //Fetch vector entry
@@ -157,7 +157,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << leaf->process(dtype + " #name = " + append_width("loadv", vwidth) + "(i);") << std::endl; stream << leaf->process(dtype + " #name = " + append_width("loadv", vwidth) + "(i);") << std::endl;
//Update accumulators //Update accumulators
for (symbolic::reduce_1d* rd : reductions) for (symbolic::reduce_1d* rd : reductions)
for (uint32_t s = 0; s < vwidth; ++s) for (unsigned int s = 0; s < vwidth; ++s)
{ {
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}}); std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, vwidth)}});
if (is_indexing(rd->op().type)) if (is_indexing(rd->op().type))
@@ -199,14 +199,14 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream.inc_tab(); stream.inc_tab();
unroll_tmp(); unroll_tmp();
//Declarations //Declarations
stream << "uint32_t lid = $LOCAL_IDX_0;" << std::endl; stream << "unsigned int lid = $LOCAL_IDX_0;" << std::endl;
stream << "uint32_t lsize = $LOCAL_SIZE_0;" << std::endl; stream << "unsigned int lsize = $LOCAL_SIZE_0;" << std::endl;
for (symbolic::reduce_1d* rd: reductions) for (symbolic::reduce_1d* rd: reductions)
{ {
if (is_indexing(rd->op().type)) if (is_indexing(rd->op().type))
{ {
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];"); stream << rd->process("$LOCAL unsigned int #name_buf[" + tools::to_string(ls0_) + "];");
stream << rd->process("uint32_t #name_acc = 0;") << std::endl; stream << rd->process("unsigned int #name_acc = 0;") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl; stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";"); stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
} }
@@ -217,7 +217,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
} }
} }
//Private reduction //Private reduction
stream << "for(uint32_t i = lid; i < " << ng_ << "; i += lsize)" << std::endl; stream << "for(unsigned int i = lid; i < " << ng_ << "; i += lsize)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
for (symbolic::reduce_1d* rd: reductions) for (symbolic::reduce_1d* rd: reductions)
@@ -249,7 +249,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
return stream.str(); return stream.str();
} }
reduce_1d::reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch): reduce_1d::reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch):
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch) base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
{} {}
@@ -280,13 +280,13 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
//Arguments //Arguments
for (auto & kernel : kernels) for (auto & kernel : kernels)
{ {
uint32_t n_arg = 0; unsigned int n_arg = 0;
kernel.setSizeArg(n_arg++, size); kernel.setSizeArg(n_arg++, size);
kernel.setArg(n_arg++, driver::backend::workspaces::get(queue)); kernel.setArg(n_arg++, driver::backend::workspaces::get(queue));
symbolic::set_arguments(x, kernel, n_arg); symbolic::set_arguments(x, kernel, n_arg);
} }
for (uint32_t k = 0; k < 2; k++) for (unsigned int k = 0; k < 2; k++)
control.execution_options().enqueue(program.context(), kernels[k], global[k], local[k]); control.execution_options().enqueue(program.context(), kernels[k], global[k], local[k]);
queue.synchronize(); queue.synchronize();
} }

View File

@@ -46,12 +46,12 @@ int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &)
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
uint32_t reduce_2d::lmem_usage(const expression_tree&) const unsigned int reduce_2d::lmem_usage(const expression_tree&) const
{ {
return (ls0_+1)*ls1_; return (ls0_+1)*ls1_;
} }
uint32_t reduce_2d::temporary_workspace(expression_tree const & expressions) const unsigned int reduce_2d::temporary_workspace(expression_tree const & expressions) const
{ {
std::vector<int_t> MN = input_sizes(expressions); std::vector<int_t> MN = input_sizes(expressions);
int_t M = MN[0]; int_t M = MN[0];
@@ -74,12 +74,12 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
name[0] += suffix; name[0] += suffix;
name[1] += suffix; name[1] += suffix;
uint32_t ldls = ls0_; unsigned int ldls = ls0_;
std::string ls0ldstr = to_string(ldls); std::string ls0ldstr = to_string(ldls);
auto unroll_tmp = [&]() auto unroll_tmp = [&]()
{ {
uint32_t offset = 0; unsigned int offset = 0;
for (symbolic::reduce_2d* rd : reductions) for (symbolic::reduce_2d* rd : reductions)
{ {
numeric_type dtype = tree.dtype(); numeric_type dtype = tree.dtype();
@@ -121,7 +121,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
std::ostringstream upper; std::ostringstream upper;
upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_; upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_;
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth) element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](unsigned int cwidth)
{ {
//Declare Buffers //Declare Buffers
for (symbolic::reduce_2d* rd : reductions) for (symbolic::reduce_2d* rd : reductions)
@@ -136,7 +136,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "if (r < M)" << std::endl; stream << "if (r < M)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth) element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int rwidth)
{ {
std::string rdtype = append_width("#scalartype", rwidth); std::string rdtype = append_width("#scalartype", rwidth);
std::string cdtype = append_width("#scalartype", cwidth); std::string cdtype = append_width("#scalartype", cwidth);
@@ -152,7 +152,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
} }
//Compute //Compute
for (symbolic::reduce_2d* rd : reductions) for (symbolic::reduce_2d* rd : reductions)
for (uint32_t s = 0; s < rwidth; ++s){ for (unsigned int s = 0; s < rwidth; ++s){
std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, rwidth)}}); std::string value = rd->lhs()->evaluate({{"leaf", access_vector_type("#name", s, rwidth)}});
if (is_indexing(rd->op().type)) if (is_indexing(rd->op().type))
compute_index_reduce_1d(stream, rd->process("#name_acc"), "c*"+to_string(rwidth) + to_string(s), rd->process("#name_acc_value"), value, rd->op()); compute_index_reduce_1d(stream, rd->process("#name_acc"), "c*"+to_string(rwidth) + to_string(s), rd->process("#name_acc_value"), value, rd->op());
@@ -276,7 +276,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
return stream.str(); return stream.str();
} }
reduce_2d::reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, reduce_2d::reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch,
operation_type_family rtype) : operation_type_family rtype) :
base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch), base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch),
reduction_type_(rtype){ } reduction_type_(rtype){ }
@@ -300,16 +300,16 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
name[0] += suffix; name[0] += suffix;
name[1] += suffix; name[1] += suffix;
uint32_t nk = (ng0_==1)?1:2; unsigned int nk = (ng0_==1)?1:2;
std::vector<driver::Kernel> kernels; std::vector<driver::Kernel> kernels;
for(uint32_t k = 0 ; k < nk ; ++k) for(unsigned int k = 0 ; k < nk ; ++k)
kernels.push_back(driver::Kernel(program, name[k].c_str())); kernels.push_back(driver::Kernel(program, name[k].c_str()));
for(uint32_t k = 0 ; k < nk ; ++k) for(unsigned int k = 0 ; k < nk ; ++k)
{ {
driver::Kernel & kernel = kernels[k]; driver::Kernel & kernel = kernels[k];
uint32_t n_arg = 0; unsigned int n_arg = 0;
int_t M = MN[0]; int_t M = MN[0];
int_t N = MN[1]; int_t N = MN[1];
kernel.setSizeArg(n_arg++, M); kernel.setSizeArg(n_arg++, M);
@@ -321,14 +321,14 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
//NDRange //NDRange
driver::NDRange global[2] = { driver::NDRange(ls0_*ng0_, ls1_*ng1_), driver::NDRange(ls0_, ls1_*ng1_) }; driver::NDRange global[2] = { driver::NDRange(ls0_*ng0_, ls1_*ng1_), driver::NDRange(ls0_, ls1_*ng1_) };
driver::NDRange local[2] = { driver::NDRange(ls0_, ls1_), driver::NDRange(ls0_, ls1_) }; driver::NDRange local[2] = { driver::NDRange(ls0_, ls1_), driver::NDRange(ls0_, ls1_) };
for(uint32_t i = 0 ; i < nk ; ++i) for(unsigned int i = 0 ; i < nk ; ++i)
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]); control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
} }
reduce_2d_rows::reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, reduce_2d_rows::reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1,
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_ROWS) {} fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_ROWS) {}
reduce_2d_cols::reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, reduce_2d_cols::reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1,
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_COLUMNS) {} fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_COLUMNS) {}

View File

@@ -96,7 +96,7 @@ std::string hash(expression_tree const & tree)
} }
//Set arguments //Set arguments
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, uint32_t& current_arg) void set_arguments(expression_tree const & tree, driver::Kernel & kernel, unsigned int& current_arg)
{ {
driver::backend_type backend = tree.context().backend(); driver::backend_type backend = tree.context().backend();