Code Quality: Removed useless/buggy fetch_type tuning parameter

This commit is contained in:
Philippe Tillet
2016-10-05 19:10:12 -04:00
parent 52fc41461a
commit 87cb0ab375
19 changed files with 77 additions and 182 deletions

View File

@@ -40,13 +40,6 @@ namespace isaac
namespace templates namespace templates
{ {
enum fetch_type
{
FETCH_FROM_LOCAL,
FETCH_FROM_GLOBAL_STRIDED,
FETCH_FROM_GLOBAL_CONTIGUOUS
};
//Error codes //Error codes
static const int TEMPLATE_VALID = 0; static const int TEMPLATE_VALID = 0;
static const int TEMPLATE_LOCAL_MEMORY_OVERFLOW = -1; static const int TEMPLATE_LOCAL_MEMORY_OVERFLOW = -1;
@@ -105,6 +98,7 @@ class parameterized_base : public base
{ {
private: private:
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const; virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
public: public:
parameterized_base(unsigned int _vwidth, int_t _ls0, int_t _ls1); parameterized_base(unsigned int _vwidth, int_t _ls0, int_t _ls1);
unsigned int ls0() const; unsigned int ls0() const;

View File

@@ -32,16 +32,14 @@ namespace templates
class elementwise_1d : public parameterized_base class elementwise_1d : public parameterized_base
{ {
private: private:
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const;
public: public:
elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch); elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
expression_type type() const; expression_type type() const;
private: private:
unsigned int ng_; unsigned int ng_;
fetch_type fetch_;
}; };
} }

View File

@@ -36,14 +36,13 @@ private:
int is_invalid_impl(driver::Device const &, expression_tree const &) const; int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
public: public:
elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch); elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
expression_type type() const; expression_type type() const;
private: private:
unsigned int ng0_; unsigned int ng0_;
unsigned int ng1_; unsigned int ng1_;
fetch_type fetch_;
}; };
} }

View File

@@ -59,8 +59,8 @@ private:
public: public:
gemm(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D gemm(unsigned int simd, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, int_t lf0, int_t lf1
, int_t lf0, int_t lf1, char A_trans, char B_trans); , char A_trans, char B_trans);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & h); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & h);
expression_type type() const; expression_type type() const;
@@ -75,9 +75,6 @@ private:
unsigned int kS_; unsigned int kS_;
unsigned int nS_; unsigned int nS_;
fetch_type Afetch_;
fetch_type Bfetch_;
unsigned int lf0_; unsigned int lf0_;
unsigned int lf1_; unsigned int lf1_;
@@ -93,16 +90,14 @@ class gemm_nn : public gemm
{ {
public: public:
gemm_nn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, int_t lf0, int_t lf1);
, int_t lf0, int_t lf1);
}; };
class gemm_tn : public gemm class gemm_tn : public gemm
{ {
public: public:
gemm_tn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tn(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, int_t lf0, int_t lf1);
, int_t lf0, int_t lf1);
}; };
@@ -110,8 +105,7 @@ class gemm_nt : public gemm
{ {
public: public:
gemm_nt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_nt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, int_t lf0, int_t lf1);
, int_t lf0, int_t lf1);
}; };
@@ -119,8 +113,7 @@ class gemm_tt : public gemm
{ {
public: public:
gemm_tt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D gemm_tt(unsigned int vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t ms, int_t ks, int_t ns, int_t lf0, int_t lf1);
, int_t lf0, int_t lf1);
}; };
} }

View File

@@ -33,21 +33,19 @@ class reduce_1d : public parameterized_base
{ {
private: private:
unsigned int lmem_usage(expression_tree const & expressions) const; unsigned int lmem_usage(expression_tree const & expressions) const;
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
unsigned int temporary_workspace(expression_tree const & expressions) const; unsigned int temporary_workspace(expression_tree const & expressions) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs, inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const; std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
public: public:
reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch); reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng);
std::vector<int_t> input_sizes(expression_tree const & expressions) const; std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
expression_type type() const; expression_type type() const;
private: private:
unsigned int ng_; unsigned int ng_;
fetch_type fetch_;
std::vector< driver::Buffer > tmp_; std::vector< driver::Buffer > tmp_;
std::vector< driver::Buffer > tmpidx_; std::vector< driver::Buffer > tmpidx_;
}; };

View File

@@ -35,9 +35,8 @@ namespace templates
class reduce_2d : public parameterized_base class reduce_2d : public parameterized_base
{ {
protected: protected:
reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch, operation_type_family); reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, operation_type_family);
private: private:
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
unsigned int lmem_usage(expression_tree const &) const; unsigned int lmem_usage(expression_tree const &) const;
unsigned int temporary_workspace(expression_tree const & expressions) const; unsigned int temporary_workspace(expression_tree const & expressions) const;
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const; std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
@@ -48,20 +47,19 @@ public:
private: private:
unsigned int ng0_; unsigned int ng0_;
unsigned int ng1_; unsigned int ng1_;
fetch_type fetch_;
operation_type_family reduction_type_; operation_type_family reduction_type_;
}; };
class reduce_2d_rows : public reduce_2d class reduce_2d_rows : public reduce_2d
{ {
public: public:
reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch); reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1);
}; };
class reduce_2d_cols : public reduce_2d class reduce_2d_cols : public reduce_2d
{ {
public: public:
reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch); reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1);
}; };
} }

View File

@@ -36,13 +36,6 @@ namespace isaac
namespace templates namespace templates
{ {
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
expression_type elementwise_1d::type() const expression_type elementwise_1d::type() const
{ return ELEMENTWISE_1D; } { return ELEMENTWISE_1D; }
@@ -78,7 +71,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
stream.inc_tab(); stream.inc_tab();
} }
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth) element_wise_loop_1D(stream, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", [&](unsigned int vwidth)
{ {
std::string dtype = append_width("#scalartype",vwidth); std::string dtype = append_width("#scalartype",vwidth);
@@ -113,8 +106,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
return stream.str(); return stream.str();
} }
elementwise_1d::elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch): elementwise_1d::elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng):
parameterized_base(vwidth,ls,1), ng_(ng), fetch_(fetch) parameterized_base(vwidth,ls,1), ng_(ng)
{} {}

View File

@@ -37,8 +37,6 @@ int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree cons
{ {
if (vwidth_>1) if (vwidth_>1)
return TEMPLATE_INVALID_SIMD_WIDTH; return TEMPLATE_INVALID_SIMD_WIDTH;
if(fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
@@ -69,8 +67,8 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
element_wise_loop_1D(stream, fetch_, 1, "i", "M", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int){ element_wise_loop_1D(stream, 1, "i", "M", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", [&](unsigned int){
element_wise_loop_1D(stream, fetch_, 1, "j", "N", "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](unsigned int){ element_wise_loop_1D(stream, 1, "j", "N", "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", [&](unsigned int){
//Declares register to store results //Declares register to store results
for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assigned_left, false)) for(symbolic::leaf* sym: symbolic::extract<symbolic::leaf>(tree, symbols, assigned_left, false))
stream << sym->process("#scalartype #name;") << std::endl; stream << sym->process("#scalartype #name;") << std::endl;
@@ -96,8 +94,8 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
} }
elementwise_2d::elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, elementwise_2d::elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1,
unsigned int ng0, unsigned int ng1, fetch_type fetch): unsigned int ng0, unsigned int ng1):
parameterized_base(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch) parameterized_base(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1)
{} {}
std::vector<int_t> elementwise_2d::input_sizes(expression_tree const & expression) const{ std::vector<int_t> elementwise_2d::input_sizes(expression_tree const & expression) const{

View File

@@ -153,9 +153,6 @@ unsigned int gemm::temporary_workspace(expression_tree const & expressions) cons
int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
{ {
if(Afetch_!=FETCH_FROM_LOCAL || Bfetch_!=FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
if ((mS_ % vwidth_) > 0 || (nS_ % vwidth_) > 0) if ((mS_ % vwidth_) > 0 || (nS_ % vwidth_) > 0)
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE; return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
@@ -165,12 +162,9 @@ int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
if ( kS_ % kL_ == 0) if ( kS_ % kL_ == 0)
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL; return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
if (Afetch_==FETCH_FROM_LOCAL || Bfetch_==FETCH_FROM_LOCAL){ if ((lf0_*lf1_) !=(ls0_*ls1_))
if ((lf0_*lf1_) !=(ls0_*ls1_)) return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
}
if (Afetch_==FETCH_FROM_LOCAL)
{ {
unsigned int bound1 = (A_trans_=='N')?kL_:mL_; unsigned int bound1 = (A_trans_=='N')?kL_:mL_;
unsigned int bound0 = (A_trans_=='N')?mL_:kL_; unsigned int bound0 = (A_trans_=='N')?mL_:kL_;
@@ -182,7 +176,7 @@ int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE; return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
} }
if (Bfetch_==FETCH_FROM_LOCAL)
{ {
unsigned int bound1 = (B_trans_=='T')?kL_:nL_; unsigned int bound1 = (B_trans_=='T')?kL_:nL_;
unsigned int bound0 = (B_trans_=='T')?nL_:kL_; unsigned int bound0 = (B_trans_=='T')?nL_:kL_;
@@ -757,10 +751,9 @@ void gemm::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K
gemm::gemm(unsigned int vwidth gemm::gemm(unsigned int vwidth
,int_t ls0, int_t kL, int_t ls1, int_t D ,int_t ls0, int_t kL, int_t ls1, int_t D
,int_t ms, int_t ks, int_t ns ,int_t ms, int_t ks, int_t ns
,fetch_type Afetch , fetch_type Bfetch
,int_t lf0, int_t lf1, char A_trans, char B_trans) : ,int_t lf0, int_t lf1, char A_trans, char B_trans) :
parameterized_base(vwidth, ls0, ls1), mL_(ms*ls0), kL_(kL), nL_(ns*ls1), depth_(D), mS_(ms), kS_(ks), nS_(ns), parameterized_base(vwidth, ls0, ls1), mL_(ms*ls0), kL_(kL), nL_(ns*ls1), depth_(D), mS_(ms), kS_(ks)
Afetch_(Afetch), Bfetch_(Bfetch), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans) , nS_(ns), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans)
{ {
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN; if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN;
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN; else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN;
@@ -795,9 +788,8 @@ void gemm::enqueue(driver::CommandQueue & queue, driver::Program const & program
gemm_nn::gemm_nn(unsigned int vwidth gemm_nn::gemm_nn(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1) : , int_t lf0, int_t lf1) :
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'N') gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, lf0, lf1, 'N', 'N')
{ {
} }
@@ -805,27 +797,24 @@ gemm_nn::gemm_nn(unsigned int vwidth
gemm_tn::gemm_tn(unsigned int vwidth gemm_tn::gemm_tn(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1) : , int_t lf0, int_t lf1) :
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'N') gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, lf0, lf1, 'T', 'N')
{ } { }
// //
gemm_nt::gemm_nt(unsigned int vwidth gemm_nt::gemm_nt(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1) : , int_t lf0, int_t lf1) :
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'T') gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, lf0, lf1, 'N', 'T')
{ } { }
// //
gemm_tt::gemm_tt(unsigned int vwidth gemm_tt::gemm_tt(unsigned int vwidth
, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns , int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
, int_t lf0, int_t lf1) : , int_t lf0, int_t lf1) :
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'T') gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, lf0, lf1, 'T', 'T')
{ } { }
} }

View File

@@ -41,13 +41,6 @@ unsigned int reduce_1d::lmem_usage(expression_tree const & x) const
return ls0_*size_of(x.dtype()); return ls0_*size_of(x.dtype());
} }
int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int reduce_1d::temporary_workspace(expression_tree const &) const unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
{ {
if(ng_ > 1) if(ng_ > 1)
@@ -149,7 +142,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl; stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
} }
} }
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int vwidth) element_wise_loop_1D(stream, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", [&](unsigned int vwidth)
{ {
std::string dtype = append_width("#scalartype",vwidth); std::string dtype = append_width("#scalartype",vwidth);
//Fetch vector entry //Fetch vector entry
@@ -252,8 +245,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
return stream.str(); return stream.str();
} }
reduce_1d::reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch): reduce_1d::reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng):
parameterized_base(vwidth,ls,1), ng_(ng), fetch_(fetch) parameterized_base(vwidth,ls,1), ng_(ng)
{} {}
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const

View File

@@ -39,13 +39,6 @@ namespace isaac
namespace templates namespace templates
{ {
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int reduce_2d::lmem_usage(const expression_tree&) const unsigned int reduce_2d::lmem_usage(const expression_tree&) const
{ {
return (ls0_+1)*ls1_; return (ls0_+1)*ls1_;
@@ -121,7 +114,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
std::ostringstream upper; std::ostringstream upper;
upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_; upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_;
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](unsigned int cwidth) element_wise_loop_1D(stream, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", [&](unsigned int cwidth)
{ {
//Declare Buffers //Declare Buffers
for (symbolic::reduce_2d* rd : reductions) for (symbolic::reduce_2d* rd : reductions)
@@ -136,7 +129,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "if (r < M)" << std::endl; stream << "if (r < M)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](unsigned int rwidth) element_wise_loop_1D(stream, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", [&](unsigned int rwidth)
{ {
std::string rdtype = append_width("#scalartype", rwidth); std::string rdtype = append_width("#scalartype", rwidth);
std::string cdtype = append_width("#scalartype", cwidth); std::string cdtype = append_width("#scalartype", cwidth);
@@ -276,9 +269,9 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
return stream.str(); return stream.str();
} }
reduce_2d::reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch, reduce_2d::reduce_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1,
operation_type_family rtype) : operation_type_family rtype) :
parameterized_base(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch), parameterized_base(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1),
reduction_type_(rtype){ } reduction_type_(rtype){ }
std::vector<int_t> reduce_2d::input_sizes(expression_tree const & tree) const std::vector<int_t> reduce_2d::input_sizes(expression_tree const & tree) const
@@ -333,11 +326,9 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]); control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
} }
reduce_2d_rows::reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, reduce_2d_rows::reduce_2d_rows(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1): reduce_2d(vwidth, ls0, ls1, ng0, ng1, REDUCE_ROWS) {}
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_ROWS) {}
reduce_2d_cols::reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, reduce_2d_cols::reduce_2d_cols(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1): reduce_2d(vwidth, ls0, ls1, ng0, ng1, REDUCE_COLUMNS) {}
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_COLUMNS) {}
} }

View File

@@ -29,40 +29,15 @@ namespace isaac
namespace templates namespace templates
{ {
inline void fetching_loop_info(fetch_type policy, std::string const & bound, kernel_generation_stream & stream, std::string & init, std::string & upper_bound, std::string & inc, std::string const & domain_id, std::string const & domain_size, driver::Device const &, std::string const & vwidth)
{
if (policy==FETCH_FROM_GLOBAL_STRIDED)
{
init = domain_id + "*" + vwidth;
upper_bound = bound;
inc = domain_size + "*" + vwidth;
}
else if (policy==FETCH_FROM_GLOBAL_CONTIGUOUS)
{
std::string chunk_size = "chunk_size";
std::string chunk_start = "chunk_start";
std::string chunk_end = "chunk_end";
stream << "$SIZE_T " << chunk_size << " = " << vwidth << "*(" << bound << "+" << domain_size << "-1)/(" << vwidth << ");" << std::endl;
stream << "$SIZE_T " << chunk_start << " =" << domain_id << "*" << chunk_size << ";" << std::endl;
stream << "$SIZE_T " << chunk_end << " = min(" << chunk_start << "+" << chunk_size << ", " << bound << ");" << std::endl;
init = chunk_start;
upper_bound = chunk_end;
inc = vwidth;
}
}
template<class Fun> template<class Fun>
inline void element_wise_loop_1D(kernel_generation_stream & stream, fetch_type fetch, unsigned int vwidth, inline void element_wise_loop_1D(kernel_generation_stream & stream, unsigned int vwidth,
std::string const & i, std::string const & bound, std::string const & domain_id, std::string const & domain_size, driver::Device const & device, Fun const & generate_body) std::string const & i, std::string const & bound, std::string const & domain_id, std::string const & domain_size, Fun const & generate_body)
{ {
std::string strwidth = tools::to_string(vwidth); std::string svwidth = tools::to_string(vwidth);
std::string init = domain_id + "*" + svwidth;
std::string init, upper_bound, inc; std::string lbound = bound + "/" + svwidth + "*" + svwidth;
fetching_loop_info(fetch, bound, stream, init, upper_bound, inc, domain_id, domain_size, device, strwidth); std::string inc = domain_size + "*" + svwidth;
std::string boundround = upper_bound + "/" + strwidth + "*" + strwidth; stream << "for(unsigned int " << i << " = " << init << "; " << i << " < " << lbound << "; " << i << " += " << inc << ")" << std::endl;
stream << "for(unsigned int " << i << " = " << init << "; " << i << " < " << boundround << "; " << i << " += " << inc << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
generate_body(vwidth); generate_body(vwidth);
@@ -71,7 +46,7 @@ inline void element_wise_loop_1D(kernel_generation_stream & stream, fetch_type f
if (vwidth>1) if (vwidth>1)
{ {
stream << "for(unsigned int " << i << " = " << boundround << " + " << domain_id << "; " << i << " < " << bound << "; " << i << " += " + domain_size + ")" << std::endl; stream << "for(unsigned int " << i << " = " << lbound << " + " << domain_id << "; " << i << " < " << bound << "; " << i << " += " + domain_size + ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
generate_body(1); generate_body(1);

File diff suppressed because one or more lines are too long

View File

@@ -152,25 +152,24 @@ std::shared_ptr<templates::base> profiles::create(std::string const & op, std::s
std::shared_ptr<templates::base> profiles::create(std::string const & template_name, std::vector<int> const & x) std::shared_ptr<templates::base> profiles::create(std::string const & template_name, std::vector<int> const & x)
{ {
templates::fetch_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="elementwise_1d") if(template_name=="elementwise_1d")
return std::shared_ptr<templates::base>(new templates::elementwise_1d(x[0], x[1], x[2], fetch[x[3]])); return std::shared_ptr<templates::base>(new templates::elementwise_1d(x[0], x[1], x[2]));
else if(template_name=="reduce_1d") else if(template_name=="reduce_1d")
return std::shared_ptr<templates::base>(new templates::reduce_1d(x[0], x[1], x[2], fetch[x[3]])); return std::shared_ptr<templates::base>(new templates::reduce_1d(x[0], x[1], x[2]));
else if(template_name=="elementwise_2d") else if(template_name=="elementwise_2d")
return std::shared_ptr<templates::base>(new templates::elementwise_2d(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); return std::shared_ptr<templates::base>(new templates::elementwise_2d(x[0], x[1], x[2], x[3], x[4]));
else if(template_name.find("reduce_2d_rows")!=std::string::npos) else if(template_name.find("reduce_2d_rows")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::reduce_2d_rows(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); return std::shared_ptr<templates::base>(new templates::reduce_2d_rows(x[0], x[1], x[2], x[3], x[4]));
else if(template_name.find("reduce_2d_cols")!=std::string::npos) else if(template_name.find("reduce_2d_cols")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::reduce_2d_cols(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); return std::shared_ptr<templates::base>(new templates::reduce_2d_cols(x[0], x[1], x[2], x[3], x[4]));
else if(template_name.find("gemm_nn")!=std::string::npos) else if(template_name.find("gemm_nn")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::gemm_nn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); return std::shared_ptr<templates::base>(new templates::gemm_nn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9]));
else if(template_name.find("gemm_tn")!=std::string::npos) else if(template_name.find("gemm_tn")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::gemm_tn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); return std::shared_ptr<templates::base>(new templates::gemm_tn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9]));
else if(template_name.find("gemm_nt")!=std::string::npos) else if(template_name.find("gemm_nt")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::gemm_nt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); return std::shared_ptr<templates::base>(new templates::gemm_nt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9]));
else if(template_name.find("gemm_tt")!=std::string::npos) else if(template_name.find("gemm_tt")!=std::string::npos)
return std::shared_ptr<templates::base>(new templates::gemm_tt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); return std::shared_ptr<templates::base>(new templates::gemm_tt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9]));
else else
throw std::invalid_argument("Invalid expression: " + template_name); throw std::invalid_argument("Invalid expression: " + template_name);
} }

View File

@@ -73,7 +73,7 @@ def main():
libraries += ['gnustl_shared'] libraries += ['gnustl_shared']
#Source files #Source files
src = 'src/lib/exception/api.cpp src/lib/exception/driver.cpp src/lib/value_scalar.cpp src/lib/random/rand.cpp src/lib/driver/check.cpp src/lib/driver/ndrange.cpp src/lib/driver/platform.cpp src/lib/driver/backend.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/event.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/buffer.cpp src/lib/driver/context.cpp src/lib/driver/dispatch.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/base.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/syntax/engine/object.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/api/blas/clBLAS.cpp src/lib/api/blas/cublas.cpp src/lib/runtime/execute.cpp src/lib/runtime/predictors/random_forest.cpp src/lib/runtime/profiles.cpp src/lib/runtime/database.cpp src/lib/array.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']] src = 'src/lib/runtime/predictors/random_forest.cpp src/lib/runtime/profiles.cpp src/lib/runtime/database.cpp src/lib/runtime/execute.cpp src/lib/exception/driver.cpp src/lib/exception/api.cpp src/lib/random/rand.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/base.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/object.cpp src/lib/value_scalar.cpp src/lib/array.cpp src/lib/api/blas/cublas.cpp src/lib/api/blas/clBLAS.cpp src/lib/driver/dispatch.cpp src/lib/driver/kernel.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/check.cpp src/lib/driver/command_queue.cpp src/lib/driver/handle.cpp src/lib/driver/context.cpp src/lib/driver/program.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/' boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']: for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x] src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -48,14 +48,6 @@ void export_templates()
bp::scope().attr("templates") = templates_module; bp::scope().attr("templates") = templates_module;
bp::scope template_scope = templates_module; bp::scope template_scope = templates_module;
bp::enum_<tpt::fetch_type>
("fetch_type")
.value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL)
.value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED)
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS);
//Base //Base
{ {
#define __PROP(name) .def_readonly(#name, &tpt::base::name) #define __PROP(name) .def_readonly(#name, &tpt::base::name)
@@ -78,17 +70,17 @@ void export_templates()
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, std::shared_ptr<tpt::name>, bp::bases<basename>>(#name, bp::init<__VA_ARGS__>())\ #define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, std::shared_ptr<tpt::name>, bp::bases<basename>>(#name, bp::init<__VA_ARGS__>())\
; ;
WRAP_TEMPLATE(elementwise_1d, tpt::parameterized_base, uint, uint, uint, tpt::fetch_type) WRAP_TEMPLATE(elementwise_1d, tpt::parameterized_base, uint, uint, uint)
WRAP_TEMPLATE(elementwise_2d, tpt::parameterized_base, uint, uint, uint, uint, uint, tpt::fetch_type) WRAP_TEMPLATE(elementwise_2d, tpt::parameterized_base, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(reduce_1d, tpt::parameterized_base, uint, uint, uint, tpt::fetch_type) WRAP_TEMPLATE(reduce_1d, tpt::parameterized_base, uint, uint, uint)
WRAP_BASE(reduce_2d) WRAP_BASE(reduce_2d)
WRAP_TEMPLATE(reduce_2d_rows, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type) WRAP_TEMPLATE(reduce_2d_rows, tpt::reduce_2d, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(reduce_2d_cols, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type) WRAP_TEMPLATE(reduce_2d_cols, tpt::reduce_2d, uint, uint, uint, uint, uint)
WRAP_BASE(gemm) WRAP_BASE(gemm)
WRAP_TEMPLATE(gemm_nn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint) WRAP_TEMPLATE(gemm_nn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(gemm_tn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint) WRAP_TEMPLATE(gemm_tn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(gemm_nt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint) WRAP_TEMPLATE(gemm_nt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(gemm_tt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint) WRAP_TEMPLATE(gemm_tt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, uint, uint)
WRAP_TEMPLATE(cublas_gemm, tpt::external_base, char, char) WRAP_TEMPLATE(cublas_gemm, tpt::external_base, char, char)

View File

@@ -34,11 +34,6 @@ import tools
from tools import profile_execution_failure from tools import profile_execution_failure
from time import sleep from time import sleep
fetch_types = [sc.templates.fetch_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
sc.templates.fetch_type.FETCH_FROM_GLOBAL_STRIDED,
sc.templates.fetch_type.FETCH_FROM_LOCAL,
sc.templates.fetch_type.FETCH_FROM_LOCAL]
class GeneticOptimizer: class GeneticOptimizer:
def __init__(self, logger, naccept=500, niter=1000, cxpb=.4, mutpb=.4, popsize=10, progress_bar = None): def __init__(self, logger, naccept=500, niter=1000, cxpb=.4, mutpb=.4, popsize=10, progress_bar = None):
@@ -77,7 +72,7 @@ class GeneticOptimizer:
result = [] result = []
for off1,off2 in zip(offsets[:-1],offsets[1:]): for off1,off2 in zip(offsets[:-1],offsets[1:]):
result += [gray2int(genome[off1:off2])] result += [gray2int(genome[off1:off2])]
result = [fetch_types[x] if i in genetic_infos['categorical'] else 2**x for i,x in enumerate(result)] result = [2**x for i,x in enumerate(result)]
return result return result
def evaluate(genome): def evaluate(genome):

View File

@@ -20,7 +20,7 @@
import isaac as sc import isaac as sc
from numpy import mean, median from numpy import mean, median
from math import ceil, exp, log, sqrt from math import ceil, exp, log, sqrt
from time import time import time
profile_execution_failure = (sc.OperationNotSupported, sc.OclLaunchOutOfResources, sc.CudaLaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue) profile_execution_failure = (sc.OperationNotSupported, sc.OclLaunchOutOfResources, sc.CudaLaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue)
def sanitize(string, keep_chars = ['_']): def sanitize(string, keep_chars = ['_']):
@@ -55,10 +55,10 @@ def benchmark(template, tree, operation=sc.templates.gemm_nn):
return float("inf") return float("inf")
#Time #Time
while total < 1e-2: while total < 1e-2:
start = time() start = time.clock()
z, events = sc.driver.enqueue(tree) z, events = sc.driver.enqueue(tree)
queue.synchronize() queue.synchronize()
end = time() end = time.clock()
times.append(end - start) times.append(end - start)
total += times[-1] total += times[-1]
i+=1 i+=1
@@ -138,15 +138,15 @@ def external_profiles(template):
def genetic_infos_of(template): def genetic_infos_of(template):
if issubclass(template, sc.templates.elementwise_1d): if issubclass(template, sc.templates.elementwise_1d):
return {'categorical': [3], 'nbits': [3,4,4,2] } return {'categorical': [], 'nbits': [3,4,4] }
elif issubclass(template, sc.templates.reduce_1d): elif issubclass(template, sc.templates.reduce_1d):
return {'categorical': [3], 'nbits':[3,4,4,2]} return {'categorical': [], 'nbits':[3,4,4]}
elif issubclass(template, sc.templates.elementwise_2d): elif issubclass(template, sc.templates.elementwise_2d):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]} return {'categorical': [], 'nbits': [3,3,3,3,4]}
elif issubclass(template, sc.templates.reduce_2d): elif issubclass(template, sc.templates.reduce_2d):
return {'categorical': [5], 'nbits': [3,3,3,3,4,2]} return {'categorical': [], 'nbits': [3,3,3,3,4]}
elif issubclass(template, sc.templates.gemm): elif issubclass(template, sc.templates.gemm):
return {'categorical': [8,9], 'nbits': [3,3,3,3,3,2,2,2,2,2,3,3]} return {'categorical': [], 'nbits': [3,3,3,3,3,2,2,2,3,3]}
def convert(profile): def convert(profile):
if isinstance(profile, str): if isinstance(profile, str):

View File

@@ -117,20 +117,9 @@ class Tuner:
with open(os.path.join(savepath, 'X.csv')) as f: with open(os.path.join(savepath, 'X.csv')) as f:
X = [tuple(map(int, row)) for row in csv.reader(f, delimiter=',')] X = [tuple(map(int, row)) for row in csv.reader(f, delimiter=',')]
with open(os.path.join(savepath, 'profiles.csv')) as f: with open(os.path.join(savepath, 'profiles.csv')) as f:
def mmap(x): profiles = [map(int,row) for v in row for row in csv.reader(f, delimiter=',')]
if x=='FETCH_FROM_LOCAL':
return sc.templates.fetch_type.FETCH_FROM_LOCAL
if x=='FETCH_FROM_GLOBAL_CONTIGUOUS':
return sc.templates.fetch_type.FETCH_FROM_GLOBAL_CONTIGUOUS
if x=='FETCH_FROM_GLOBAL_STRIDED':
return sc.templates.fetch_type.FETCH_FROM_GLOBAL_STRIDED
return int(x)
profiles = [map(mmap,row) for v in row for row in csv.reader(f, delimiter=',')]
with open(os.path.join(savepath, 'Y.csv')) as f: with open(os.path.join(savepath, 'Y.csv')) as f:
Y = [map(float, row) for row in csv.reader(f, delimiter=',')] Y = [map(float, row) for row in csv.reader(f, delimiter=',')]
#for x in X:
# tree, _ = tools.tree_of(operation, x, context)
# Y.append([performance(x, tools.benchmark(operation(*best), tree)) for best in profiles])
except: except:
pass pass
@@ -173,6 +162,7 @@ class Tuner:
yy.append(performance(xx, time)) yy.append(performance(xx, time))
#Update dataset #Update dataset
X.append(x) X.append(x)
tree, operands = tools.tree_of(operation, x, context)
y = [performance(x,tools.benchmark(operation(*prf), tree)) for prf in profiles] y = [performance(x,tools.benchmark(operation(*prf), tree)) for prf in profiles]
Y.append(y) Y.append(y)
#Save data #Save data