More cleaning
This commit is contained in:
@@ -73,15 +73,6 @@ static const int TEMPLATE_BLOCK_SIZE_TOO_LARGE = -20;
|
||||
|
||||
class base
|
||||
{
|
||||
public:
|
||||
struct parameters_type
|
||||
{
|
||||
parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels);
|
||||
uint32_t vwidth;
|
||||
uint32_t ls0;
|
||||
uint32_t ls1;
|
||||
uint32_t nkernels;
|
||||
};
|
||||
private:
|
||||
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0;
|
||||
public:
|
||||
@@ -97,20 +88,20 @@ public:
|
||||
};
|
||||
|
||||
|
||||
template<class TemplateType, class ParametersType>
|
||||
class base_impl : public base
|
||||
{
|
||||
private:
|
||||
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
public:
|
||||
typedef ParametersType parameters_type;
|
||||
base_impl(parameters_type const & parameters);
|
||||
base_impl(uint32_t _vwidth, int_t _ls0, int_t _ls1);
|
||||
uint32_t ls0() const;
|
||||
uint32_t ls1() const;
|
||||
/** @brief returns whether or not the profile has undefined behavior on particular device */
|
||||
int is_invalid(expression_tree const & expressions, driver::Device const & device) const;
|
||||
protected:
|
||||
parameters_type p_;
|
||||
uint32_t vwidth_;
|
||||
uint32_t ls0_;
|
||||
uint32_t ls1_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -29,24 +29,18 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
class elementwise_1d_parameters : public base::parameters_type
|
||||
{
|
||||
public:
|
||||
elementwise_1d_parameters(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch);
|
||||
uint32_t ng;
|
||||
fetch_type fetch;
|
||||
};
|
||||
|
||||
class elementwise_1d : public base_impl<elementwise_1d, elementwise_1d_parameters>
|
||||
class elementwise_1d : public base_impl
|
||||
{
|
||||
private:
|
||||
virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const;
|
||||
public:
|
||||
elementwise_1d(elementwise_1d::parameters_type const & parameters);
|
||||
elementwise_1d(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch);
|
||||
elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch);
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||
private:
|
||||
uint32_t ng_;
|
||||
fetch_type fetch_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -30,17 +30,7 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
class elementwise_2d_parameters : public base::parameters_type
|
||||
{
|
||||
public:
|
||||
elementwise_2d_parameters(uint32_t _vwidth, uint32_t _ls0, uint32_t _ls1, uint32_t _ng0, uint32_t _ng1, fetch_type _fetch);
|
||||
|
||||
uint32_t ng0;
|
||||
uint32_t ng1;
|
||||
fetch_type fetch;
|
||||
};
|
||||
|
||||
class elementwise_2d : public base_impl<elementwise_2d, elementwise_2d_parameters>
|
||||
class elementwise_2d : public base_impl
|
||||
{
|
||||
private:
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
@@ -50,6 +40,10 @@ public:
|
||||
elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||
private:
|
||||
uint32_t ng0_;
|
||||
uint32_t ng1_;
|
||||
fetch_type fetch_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -31,35 +31,7 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
struct gemm_parameters : public base::parameters_type
|
||||
{
|
||||
gemm_parameters(uint32_t vwidth
|
||||
,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D
|
||||
,uint32_t ms, uint32_t ks, uint32_t ns
|
||||
,fetch_type Afetch, fetch_type Bfetch
|
||||
,uint32_t lf0, uint32_t lf1);
|
||||
|
||||
uint32_t kL;
|
||||
uint32_t depth;
|
||||
|
||||
uint32_t mS;
|
||||
uint32_t kS;
|
||||
uint32_t nS;
|
||||
|
||||
fetch_type Afetch;
|
||||
fetch_type Bfetch;
|
||||
|
||||
uint32_t lf0;
|
||||
uint32_t lf1;
|
||||
|
||||
uint32_t mL;
|
||||
uint32_t nL;
|
||||
|
||||
bool prefetch;
|
||||
bool unroll_outer;
|
||||
};
|
||||
|
||||
class gemm : public base_impl<gemm, gemm_parameters>
|
||||
class gemm : public base_impl
|
||||
{
|
||||
private:
|
||||
uint32_t temporary_workspace(expression_tree const & expressions) const;
|
||||
@@ -71,10 +43,32 @@ private:
|
||||
value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, runtime::execution_options_type const & options);
|
||||
std::vector<int_t> infos(expression_tree const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
|
||||
public:
|
||||
gemm(gemm::parameters_type const & parameters, char A_trans, char B_trans);
|
||||
gemm(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1, char A_trans, char B_trans);
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr);
|
||||
private:
|
||||
//Parameters
|
||||
uint32_t kL_;
|
||||
uint32_t depth_;
|
||||
|
||||
uint32_t mS_;
|
||||
uint32_t kS_;
|
||||
uint32_t nS_;
|
||||
|
||||
fetch_type Afetch_;
|
||||
fetch_type Bfetch_;
|
||||
|
||||
uint32_t lf0_;
|
||||
uint32_t lf1_;
|
||||
|
||||
uint32_t mL_;
|
||||
uint32_t nL_;
|
||||
|
||||
bool prefetch_;
|
||||
bool unroll_outer_;
|
||||
//
|
||||
const char A_trans_;
|
||||
const char B_trans_;
|
||||
expression_type type_;
|
||||
@@ -83,7 +77,7 @@ private:
|
||||
class gemm_nn : public gemm
|
||||
{
|
||||
public:
|
||||
gemm_nn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
gemm_nn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1);
|
||||
};
|
||||
@@ -91,7 +85,7 @@ public:
|
||||
class gemm_tn : public gemm
|
||||
{
|
||||
public:
|
||||
gemm_tn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
gemm_tn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1);
|
||||
};
|
||||
@@ -100,7 +94,7 @@ public:
|
||||
class gemm_nt : public gemm
|
||||
{
|
||||
public:
|
||||
gemm_nt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
gemm_nt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1);
|
||||
};
|
||||
@@ -109,7 +103,7 @@ public:
|
||||
class gemm_tt : public gemm
|
||||
{
|
||||
public:
|
||||
gemm_tt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
gemm_tt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1);
|
||||
};
|
||||
|
@@ -29,16 +29,7 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
struct reduce_1d_parameters : public base::parameters_type
|
||||
{
|
||||
reduce_1d_parameters(uint32_t _vwidth,
|
||||
uint32_t _group_size, uint32_t _ng,
|
||||
fetch_type _fetch);
|
||||
uint32_t ng;
|
||||
fetch_type fetch;
|
||||
};
|
||||
|
||||
class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
|
||||
class reduce_1d : public base_impl
|
||||
{
|
||||
private:
|
||||
uint32_t lmem_usage(expression_tree const & expressions) const;
|
||||
@@ -49,11 +40,12 @@ private:
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
||||
|
||||
public:
|
||||
reduce_1d(reduce_1d::parameters_type const & parameters);
|
||||
reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch);
|
||||
reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch);
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||
private:
|
||||
uint32_t ng_;
|
||||
fetch_type fetch_;
|
||||
std::vector< driver::Buffer > tmp_;
|
||||
std::vector< driver::Buffer > tmpidx_;
|
||||
};
|
||||
|
@@ -31,21 +31,11 @@ namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
struct reduce_2d_parameters : public base::parameters_type
|
||||
{
|
||||
reduce_2d_parameters(uint32_t _vwidth,
|
||||
uint32_t _ls0, uint32_t _ls1,
|
||||
uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy);
|
||||
uint32_t ng0;
|
||||
uint32_t ng1;
|
||||
fetch_type fetch_policy;
|
||||
};
|
||||
|
||||
|
||||
class reduce_2d : public base_impl<reduce_2d, reduce_2d_parameters>
|
||||
{
|
||||
protected:
|
||||
reduce_2d(reduce_2d::parameters_type const & , operation_type_family);
|
||||
reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, operation_type_family);
|
||||
private:
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
uint32_t lmem_usage(expression_tree const &) const;
|
||||
@@ -55,21 +45,22 @@ public:
|
||||
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||
private:
|
||||
uint32_t ng0_;
|
||||
uint32_t ng1_;
|
||||
fetch_type fetch_;
|
||||
operation_type_family reduction_type_;
|
||||
};
|
||||
|
||||
class reduce_2d_rows : public reduce_2d
|
||||
{
|
||||
public:
|
||||
reduce_2d_rows(reduce_2d::parameters_type const &);
|
||||
reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||
reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch);
|
||||
};
|
||||
|
||||
class reduce_2d_cols : public reduce_2d
|
||||
{
|
||||
public:
|
||||
reduce_2d_cols(reduce_2d::parameters_type const &);
|
||||
reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||
reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch);
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -40,9 +40,6 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
base::parameters_type::parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels) : vwidth(_vwidth), ls0(_ls0), ls1(_ls1), nkernels(_nkernels)
|
||||
{ }
|
||||
|
||||
base::base()
|
||||
{}
|
||||
|
||||
@@ -69,24 +66,20 @@ std::string base::generate(std::string const & suffix, expression_tree const &
|
||||
return generate_impl(suffix, expression, device, mapping);
|
||||
}
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
int base_impl::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{ return TEMPLATE_VALID; }
|
||||
|
||||
template<class TType, class PType>
|
||||
base_impl<TType, PType>::base_impl(parameters_type const & parameters) : base(), p_(parameters)
|
||||
base_impl::base_impl(uint32_t vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls0_(ls0), ls1_(ls1)
|
||||
{ }
|
||||
|
||||
template<class TType, class PType>
|
||||
uint32_t base_impl<TType, PType>::ls0() const
|
||||
uint32_t base_impl::ls0() const
|
||||
{ return p_.ls0; }
|
||||
|
||||
template<class TType, class PType>
|
||||
uint32_t base_impl<TType, PType>::ls1() const
|
||||
uint32_t base_impl::ls1() const
|
||||
{ return p_.ls1; }
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
||||
int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
||||
{
|
||||
//Query device informations
|
||||
size_t lmem_available = device.local_mem_size();
|
||||
@@ -112,11 +105,5 @@ int base_impl<TType, PType>::is_invalid(expression_tree const & expressions, dr
|
||||
return is_invalid_impl(device, expressions);
|
||||
}
|
||||
|
||||
template class base_impl<elementwise_1d, elementwise_1d_parameters>;
|
||||
template class base_impl<reduce_1d, reduce_1d_parameters>;
|
||||
template class base_impl<elementwise_2d, elementwise_2d_parameters>;
|
||||
template class base_impl<reduce_2d, reduce_2d_parameters>;
|
||||
template class base_impl<gemm, gemm_parameters>;
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -36,14 +36,6 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
elementwise_1d_parameters::elementwise_1d_parameters(uint32_t _vwidth,
|
||||
uint32_t _group_size, uint32_t _ng,
|
||||
fetch_type _fetch) :
|
||||
base::parameters_type(_vwidth, _group_size, 1, 1), ng(_ng), fetch(_fetch)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetch==FETCH_FROM_LOCAL)
|
||||
@@ -118,13 +110,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
elementwise_1d::elementwise_1d(elementwise_1d_parameters const & parameters) :
|
||||
base_impl<elementwise_1d, elementwise_1d_parameters>(parameters)
|
||||
{}
|
||||
|
||||
elementwise_1d::elementwise_1d(uint32_t simd, uint32_t ls, uint32_t ng,
|
||||
fetch_type fetch):
|
||||
base_impl<elementwise_1d, elementwise_1d_parameters>(elementwise_1d_parameters(simd,ls,ng,fetch))
|
||||
elementwise_1d::elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
|
||||
{}
|
||||
|
||||
|
||||
|
@@ -33,13 +33,6 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
elementwise_2d_parameters::elementwise_2d_parameters(uint32_t _vwidth,
|
||||
uint32_t _ls0, uint32_t _ls1,
|
||||
uint32_t _ng0, uint32_t _ng1,
|
||||
fetch_type _fetch) : base::parameters_type(_vwidth, _ls0, _ls1, 1), ng0(_ng0), ng1(_ng1), fetch(_fetch){ }
|
||||
|
||||
|
||||
|
||||
int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.vwidth>1)
|
||||
@@ -111,12 +104,9 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
elementwise_2d::elementwise_2d(parameters_type const & parameters) :
|
||||
base_impl<elementwise_2d, elementwise_2d_parameters>(parameters){ }
|
||||
|
||||
elementwise_2d::elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2,
|
||||
uint32_t ng1, uint32_t ng2, fetch_type fetch):
|
||||
base_impl<elementwise_2d, elementwise_2d_parameters>(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch))
|
||||
elementwise_2d::elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1,
|
||||
uint32_t ng0, uint32_t ng1, fetch_type fetch):
|
||||
base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch)
|
||||
{}
|
||||
|
||||
std::vector<int_t> elementwise_2d::input_sizes(expression_tree const & expression) const{
|
||||
|
@@ -37,18 +37,6 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
gemm_parameters::gemm_parameters(uint32_t vwidth
|
||||
,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D
|
||||
,uint32_t ms, uint32_t ks, uint32_t ns
|
||||
,fetch_type Afetch, fetch_type Bfetch
|
||||
,uint32_t lf0, uint32_t lf1): base::parameters_type(vwidth, ls0, ls1, 1),
|
||||
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
|
||||
lf0(lf0), lf1(lf1),
|
||||
mL(ms*ls0), nL(ns*ls1)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
uint32_t gemm::lmem_usage(expression_tree const & expression) const
|
||||
{
|
||||
uint32_t N = 0;
|
||||
@@ -689,7 +677,13 @@ gemm_parameters::gemm_parameters(uint32_t vwidth
|
||||
return {M, N, K};
|
||||
}
|
||||
|
||||
gemm::gemm(gemm_parameters const & parameters, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters), A_trans_(A_trans), B_trans_(B_trans)
|
||||
gemm::gemm(uint32_t vwidth
|
||||
,int_t ls0, int_t kL, int_t ls1, int_t D
|
||||
,int_t ms, int_t ks, int_t ns
|
||||
,fetch_type Afetch , fetch_type Bfetch
|
||||
,int_t lf0, int_t lf1, char A_trans, char B_trans) :
|
||||
base_impl(vwidth, ls0, ls1), kL_(kL), depth_(D), mS_(ms), kS_(ks), nS_(ns),
|
||||
Afetch_(Afetch), Bfetch_(Bfetch), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans)
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN;
|
||||
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN;
|
||||
@@ -721,40 +715,40 @@ gemm_parameters::gemm_parameters(uint32_t vwidth
|
||||
}
|
||||
|
||||
//
|
||||
gemm_nn::gemm_nn(uint32_t simd
|
||||
gemm_nn::gemm_nn(uint32_t vwidth
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'N')
|
||||
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'N')
|
||||
{
|
||||
}
|
||||
|
||||
//
|
||||
gemm_tn::gemm_tn(uint32_t simd
|
||||
gemm_tn::gemm_tn(uint32_t vwidth
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'N')
|
||||
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'N')
|
||||
{ }
|
||||
|
||||
//
|
||||
gemm_nt::gemm_nt(uint32_t simd
|
||||
gemm_nt::gemm_nt(uint32_t vwidth
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'T')
|
||||
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'T')
|
||||
{ }
|
||||
|
||||
//
|
||||
gemm_tt::gemm_tt(uint32_t simd
|
||||
gemm_tt::gemm_tt(uint32_t vwidth
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lf0, int_t lf1) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'T')
|
||||
gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'T')
|
||||
{ }
|
||||
|
||||
}
|
||||
|
@@ -35,10 +35,6 @@ namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
reduce_1d_parameters::reduce_1d_parameters(uint32_t _vwidth,
|
||||
uint32_t _group_size, uint32_t _ng,
|
||||
fetch_type _fetch) : base::parameters_type(_vwidth, _group_size, 1, 2), ng(_ng), fetch(_fetch)
|
||||
{ }
|
||||
|
||||
uint32_t reduce_1d::lmem_usage(expression_tree const & x) const
|
||||
{
|
||||
@@ -253,11 +249,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
reduce_1d::reduce_1d(reduce_1d::parameters_type const & parameters) : base_impl<reduce_1d, reduce_1d_parameters>(parameters)
|
||||
{ }
|
||||
|
||||
reduce_1d::reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||
base_impl<reduce_1d, reduce_1d_parameters>(reduce_1d_parameters(simd,ls,ng,fetch))
|
||||
reduce_1d::reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
|
||||
{}
|
||||
|
||||
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const
|
||||
|
@@ -39,15 +39,9 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
reduce_2d_parameters::reduce_2d_parameters(uint32_t _vwidth,
|
||||
uint32_t _ls0, uint32_t _ls1,
|
||||
uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy): base::parameters_type(_vwidth, _ls0, _ls1, 1),
|
||||
ng0(_ng0), ng1(_ng1), fetch_policy(_fetch_policy) { }
|
||||
|
||||
|
||||
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetch_policy==FETCH_FROM_LOCAL)
|
||||
if (p_.fetch==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
@@ -127,7 +121,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
std::ostringstream upper;
|
||||
upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1;
|
||||
|
||||
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
|
||||
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
|
||||
{
|
||||
//Declare Buffers
|
||||
for (symbolic::reduce_2d* rd : reductions)
|
||||
@@ -142,7 +136,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "if (r < M)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
|
||||
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
|
||||
{
|
||||
std::string rdtype = append_width("#scalartype", rwidth);
|
||||
std::string cdtype = append_width("#scalartype", cwidth);
|
||||
@@ -282,9 +276,9 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
reduce_2d::reduce_2d(reduce_2d::parameters_type const & parameters,
|
||||
reduce_2d::reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch,
|
||||
operation_type_family rtype) :
|
||||
base_impl<reduce_2d, reduce_2d_parameters>(parameters),
|
||||
base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch),
|
||||
reduction_type_(rtype){ }
|
||||
|
||||
std::vector<int_t> reduce_2d::input_sizes(expression_tree const & tree) const
|
||||
@@ -331,15 +325,11 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
||||
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
||||
}
|
||||
|
||||
reduce_2d_rows::reduce_2d_rows(reduce_2d_parameters const & parameters): reduce_2d(parameters, REDUCE_ROWS){}
|
||||
reduce_2d_rows::reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1,
|
||||
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_ROWS) {}
|
||||
|
||||
reduce_2d_rows::reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2,
|
||||
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS) {}
|
||||
|
||||
reduce_2d_cols::reduce_2d_cols(reduce_2d::parameters_type const & parameters): reduce_2d(parameters, REDUCE_COLUMNS){}
|
||||
|
||||
reduce_2d_cols::reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2,
|
||||
fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS) {}
|
||||
reduce_2d_cols::reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1,
|
||||
fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_COLUMNS) {}
|
||||
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user