Python: Fixed wrapper issues induced after cleaning

This commit is contained in:
Philippe Tillet
2016-10-03 02:23:20 -04:00
parent 4fbbd1a27a
commit 31849794e8
11 changed files with 210 additions and 209 deletions

View File

@@ -71,7 +71,7 @@ static const int TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE = -18;
static const int TEMPLATE_TEMPORARY_TOO_LARGE = -19;
static const int TEMPLATE_BLOCK_SIZE_TOO_LARGE = -20;
class base
class base: public std::enable_shared_from_this<base>
{
private:
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0;
@@ -85,6 +85,9 @@ public:
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
std::shared_ptr<base> getptr() {
return shared_from_this();
}
};

View File

@@ -36,8 +36,7 @@ private:
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
public:
elementwise_2d(parameters_type const & parameters);
elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
private:

View File

@@ -32,7 +32,7 @@ namespace isaac
namespace templates
{
class reduce_2d : public base_impl<reduce_2d, reduce_2d_parameters>
class reduce_2d : public base_impl
{
protected:
reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, operation_type_family);

View File

@@ -73,12 +73,11 @@ base_impl::base_impl(uint32_t vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls
{ }
uint32_t base_impl::ls0() const
{ return p_.ls0; }
{ return ls0_; }
uint32_t base_impl::ls1() const
{ return p_.ls1; }
{ return ls1_; }
template<class TType, class PType>
int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const
{
//Query device informations
@@ -90,16 +89,16 @@ int base_impl::is_invalid(expression_tree const & expressions, driver::Device c
//Invalid work group size
size_t max_workgroup_size = device.max_work_group_size();
std::vector<size_t> max_work_item_sizes = device.max_work_item_sizes();
if (p_.ls0*p_.ls1 > max_workgroup_size)
if (ls0_*ls1_ > max_workgroup_size)
return TEMPLATE_WORK_GROUP_SIZE_OVERFLOW;
if (p_.ls0 > max_work_item_sizes[0])
if (ls0_ > max_work_item_sizes[0])
return TEMPLATE_LOCAL_SIZE_0_OVERFLOW;
if (p_.ls1 > max_work_item_sizes[1])
if (ls1_ > max_work_item_sizes[1])
return TEMPLATE_LOCAL_SIZE_1_OVERFLOW;
//Invalid SIMD Width
if (p_.vwidth!=1 && p_.vwidth!=2 && p_.vwidth!=3 && p_.vwidth!=4)
if (vwidth_!=1 && vwidth_!=2 && vwidth_!=3 && vwidth_!=4)
return TEMPLATE_INVALID_SIMD_WIDTH;
return is_invalid_impl(device, expressions);

View File

@@ -38,7 +38,7 @@ namespace templates
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetch==FETCH_FROM_LOCAL)
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
@@ -57,7 +57,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl; break;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
}
stream << "$KERNEL void elementwise_1d" << suffix << "($SIZE_T N, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")";
@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
stream.inc_tab();
}
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
{
std::string dtype = append_width("#scalartype",vwidth);
@@ -111,7 +111,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
}
elementwise_1d::elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
{}
@@ -130,8 +130,8 @@ void elementwise_1d::enqueue(driver::CommandQueue &, driver::Program const & pro
name += suffix;
driver::Kernel kernel(program, name.c_str());
//NDRange
driver::NDRange global(p_.ls0*p_.ng);
driver::NDRange local(p_.ls0);
driver::NDRange global(ls0_*ng_);
driver::NDRange local(ls0_);
//Arguments
uint32_t current_arg = 0;
kernel.setSizeArg(current_arg++, size);

View File

@@ -35,9 +35,9 @@ namespace templates
int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.vwidth>1)
if (vwidth_>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
if(p_.fetch==FETCH_FROM_LOCAL)
if(fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
@@ -60,7 +60,7 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl; break;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
}
stream << "$KERNEL void elementwise_2d" << suffix << "($SIZE_T M, $SIZE_T N, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
@@ -68,11 +68,11 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
stream.inc_tab();
fetching_loop_info(p_.fetch, "M", stream, init0, upper_bound0, inc0, "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device);
fetching_loop_info(fetch_, "M", stream, init0, upper_bound0, inc0, "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device);
stream << "for($SIZE_T i = " << init0 << "; i < " << upper_bound0 << "; i += " << inc0 << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
fetching_loop_info(p_.fetch, "N", stream, init1, upper_bound1, inc1, "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device);
fetching_loop_info(fetch_, "N", stream, init1, upper_bound1, inc1, "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device);
stream << "for($SIZE_T j = " << init1 << "; j < " << upper_bound1 << "; j += " << inc1 << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
@@ -119,8 +119,8 @@ void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program c
std::string name = "elementwise_2d";
name +=suffix;
driver::Kernel kernel(program, name.c_str());
driver::NDRange global(p_.ls0*p_.ng0, p_.ls1*p_.ng1);
driver::NDRange local(p_.ls0, p_.ls1);
driver::NDRange global(ls0_*ng0_, ls1_*ng1_);
driver::NDRange local(ls0_, ls1_);
uint32_t current_arg = 0;
std::vector<int_t> MN = input_sizes(expressions);
kernel.setSizeArg(current_arg++, MN[0]);

View File

@@ -40,10 +40,10 @@ namespace templates
uint32_t gemm::lmem_usage(expression_tree const & expression) const
{
uint32_t N = 0;
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
size_t llda = (A_trans_=='N')?mL_:kL_+1;
size_t lnda = (A_trans_=='N')?kL_:mL_;
size_t lldb = (B_trans_=='T')?nL_:kL_+1;
size_t lndb = (B_trans_=='T')?kL_:nL_;
N += llda*lnda;
N += lldb*lndb;
return N*size_of(expression.dtype());
@@ -51,7 +51,7 @@ namespace templates
uint32_t gemm::registers_usage(expression_tree const & expression) const
{
uint32_t N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
uint32_t N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;
return N*size_of(expression.dtype());
}
@@ -59,51 +59,51 @@ namespace templates
{
std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1];
if(p_.depth > 1)
return M*N*p_.depth;
if(depth_ > 1)
return M*N*depth_;
return 0;
}
int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if(p_.Afetch!=FETCH_FROM_LOCAL || p_.Bfetch!=FETCH_FROM_LOCAL)
if(Afetch_!=FETCH_FROM_LOCAL || Bfetch_!=FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
if ((p_.mS % p_.vwidth) > 0 || (p_.nS % p_.vwidth) > 0)
if ((mS_ % vwidth_) > 0 || (nS_ % vwidth_) > 0)
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
if(p_.mL > 256 || p_.nL > 256)
if(mL_ > 256 || nL_ > 256)
return TEMPLATE_BLOCK_SIZE_TOO_LARGE;
if ( p_.kS % p_.kL == 0)
if ( kS_ % kL_ == 0)
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
if (p_.Afetch==FETCH_FROM_LOCAL || p_.Bfetch==FETCH_FROM_LOCAL){
if ((p_.lf0*p_.lf1) !=(p_.ls0*p_.ls1))
if (Afetch_==FETCH_FROM_LOCAL || Bfetch_==FETCH_FROM_LOCAL){
if ((lf0_*lf1_) !=(ls0_*ls1_))
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
}
if (p_.Afetch==FETCH_FROM_LOCAL)
if (Afetch_==FETCH_FROM_LOCAL)
{
uint32_t bound1 = (A_trans_=='N')?p_.kL:p_.mL;
uint32_t bound0 = (A_trans_=='N')?p_.mL:p_.kL;
uint32_t bound1 = (A_trans_=='N')?kL_:mL_;
uint32_t bound0 = (A_trans_=='N')?mL_:kL_;
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
if (lf1_>0 && (bound1 % lf1_)> 0)
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
if (lf0_>0 && (bound0 % (lf0_*vwidth_)) > 0)
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
}
if (p_.Bfetch==FETCH_FROM_LOCAL)
if (Bfetch_==FETCH_FROM_LOCAL)
{
uint32_t bound1 = (B_trans_=='T')?p_.kL:p_.nL;
uint32_t bound0 = (B_trans_=='T')?p_.nL:p_.kL;
uint32_t bound1 = (B_trans_=='T')?kL_:nL_;
uint32_t bound0 = (B_trans_=='T')?nL_:kL_;
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
if (lf1_>0 && (bound1 % lf1_)> 0)
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
if (lf0_>0 && (bound0 % (lf0_*vwidth_)) > 0)
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
}
@@ -117,10 +117,10 @@ namespace templates
using tools::to_string;
driver::backend_type backend = device.backend();
bool has_depth = p_.depth > 1;
#define VLOAD(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, true)
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, false)
#define VSTORE(value, offset, ptr) vstore(p_.vwidth, sdtype, value, offset, ptr, "1", backend)
bool has_depth = depth_ > 1;
#define VLOAD(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, true)
#define VLOAD_MISALIGNED(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, false)
#define VSTORE(value, offset, ptr) vstore(vwidth_, sdtype, value, offset, ptr, "1", backend)
symbolic::preset::gemm::args args;
infos(tree, args);
@@ -134,7 +134,7 @@ namespace templates
kernel_generation_stream stream(backend);
numeric_type dtype = tree.dtype();
std::string sdtype = to_string(dtype);
std::string vdtype = append_width(sdtype, p_.vwidth);
std::string vdtype = append_width(sdtype, vwidth_);
//////////////////
/// DECLARATIONS
@@ -148,7 +148,7 @@ namespace templates
switch(backend)
{
case driver::OPENCL:
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
break;
default:
break;
@@ -166,20 +166,20 @@ namespace templates
///Declare
stream << "//blocks" << std::endl;
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{0}};" << std::endl;
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.vwidth << "];" << std::endl;
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.vwidth << "];" << std::endl;
stream << sdtype << " rC[" << mS_ << "][" << nS_ << "] = {{0}};" << std::endl;
stream << vdtype << " rA[" << kS_ << "][" << mS_/vwidth_ << "];" << std::endl;
stream << vdtype << " rB[" << kS_ << "][" << nS_/vwidth_ << "];" << std::endl;
stream << std::endl;
stream << "//pointers" << std::endl;
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
size_t llda = (A_trans_=='N')?mL_:kL_+1;
size_t lnda = (A_trans_=='N')?kL_:mL_;
size_t lldb = (B_trans_=='T')?nL_:kL_+1;
size_t lndb = (B_trans_=='T')?kL_:nL_;
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
uint32_t npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
uint32_t npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
uint32_t npA = mL_/(A_trans_=='N'?lf0_*vwidth_:lf1_);
uint32_t npB = nL_/(B_trans_=='T'?lf0_*vwidth_:lf1_);
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << std::endl;
@@ -204,20 +204,20 @@ namespace templates
if(has_depth)
{
stream << "gidz = $GROUP_IDX_2;" << std::endl;
stream << "div = (K+" << p_.depth-1 << ")/" << p_.depth << ";" << std::endl;
stream << "div = (K+" << depth_-1 << ")/" << depth_ << ";" << std::endl;
stream << "offz = div*gidz;" << std::endl;
stream << "K = min(K - div*gidz, ($SIZE_T)div);" << std::endl;
}
stream << "idt = " << p_.ls0 << "*ids.w + ids.z;" << std::endl;
stream << "idT.y = idt/" << p_.lf0 << ";" << std::endl;
stream << "idT.x = idt - " << p_.lf0 << "*idT.y;" << std::endl;
stream << "idt = " << ls0_ << "*ids.w + ids.z;" << std::endl;
stream << "idT.y = idt/" << lf0_ << ";" << std::endl;
stream << "idT.x = idt - " << lf0_ << "*idT.y;" << std::endl;
stream << std::endl;
stream << "//Adjust pointers and bounds per work-item" << std::endl;
stream << "ids.x *= " << p_.mL << ";" << std::endl;
stream << "ids.y *= " << p_.nL << ";" << std::endl;
stream << "idT.x *= " << p_.vwidth << ";" << std::endl;
stream << "ids.x *= " << mL_ << ";" << std::endl;
stream << "ids.y *= " << nL_ << ";" << std::endl;
stream << "idT.x *= " << vwidth_ << ";" << std::endl;
stream << "M -= ids.x;" << std::endl;
if(A_trans_=='N')
@@ -280,19 +280,19 @@ namespace templates
for(uint32_t i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < M", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
else
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < M", "(int)((idT.y + " + to_string(i*lf1_) + ")*lda)", "0") << ";" << std::endl;
for(uint32_t i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < N", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
else
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < N", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*ldb)", "0") << ";" << std::endl;
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < N", "(int)((idT.y + " + to_string(i*lf1_) + ")*ldb)", "0") << ";" << std::endl;
stream << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K >=" << p_.kL << ")" << std::endl;
stream << "while(K >=" << kL_ << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
@@ -306,13 +306,13 @@ namespace templates
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
for(uint32_t k = 0; k < kL_; k += lf1_)
for(uint32_t m = 0; m < mL_; m += lf0_*vwidth_)
{
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
std::string mm = to_string(m/(vwidth_*lf0_));
std::string kk = to_string(k);
if(last_iteration)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
@@ -320,13 +320,13 @@ namespace templates
}
else
{
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(uint32_t m = 0; m < p_.mL; m += p_.lf1)
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_)
for(uint32_t m = 0; m < mL_; m += lf1_)
{
std::string mm = to_string(m/p_.lf1);
std::string mm = to_string(m/lf1_);
std::string kk = to_string(k);
if(last_iteration)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
else
@@ -337,13 +337,13 @@ namespace templates
stream << "//Fetch B to local memory" << std::endl;
if (B_trans_=='T')
{
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
for(uint32_t k = 0; k < kL_; k += lf1_)
for(uint32_t n = 0; n < nL_; n += lf0_*vwidth_)
{
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
std::string nn = to_string(n/(vwidth_*lf0_));
std::string kk = to_string(k);
if(last_iteration)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
@@ -351,13 +351,13 @@ namespace templates
}
else
{
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(uint32_t n = 0; n < p_.nL; n += p_.lf1)
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_)
for(uint32_t n = 0; n < nL_; n += lf1_)
{
std::string nn = to_string(n/p_.lf1);
std::string nn = to_string(n/lf1_);
std::string kk = to_string(k);
if(last_iteration)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
else
@@ -366,18 +366,18 @@ namespace templates
}
if(A_trans_=='N')
stream << "ldsA = lA + ids.z*" << p_.vwidth << ";" << std::endl;
stream << "ldsA = lA + ids.z*" << vwidth_ << ";" << std::endl;
else
stream << "ldsA = lA + ids.z*" << llda*p_.vwidth << ";" << std::endl;
stream << "ldsA = lA + ids.z*" << llda*vwidth_ << ";" << std::endl;
if(B_trans_=='T')
stream << "ldsB = lB + ids.w*" << p_.vwidth << ";" << std::endl;
stream << "ldsB = lB + ids.w*" << vwidth_ << ";" << std::endl;
else
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
stream << "ldsB = lB + ids.w*" << lldb*vwidth_ << ";" << std::endl;
stream << "$LOCAL_BARRIER;" << std::endl;
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
size_t ks = last_iteration?1:p_.kS;
std::string bound = last_iteration?"K":tools::to_string(kL_);
size_t ks = last_iteration?1:kS_;
stream << "//Inner loop" << std::endl;
stream << "for(uint32_t k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
stream.inc_tab();
@@ -385,19 +385,19 @@ namespace templates
stream << "//Fetch A to registers" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
stream << "for(uint32_t mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
stream << "#pragma unroll " << mS_/vwidth_ << std::endl;
stream << "for(uint32_t mm = 0; mm < " << mS_/vwidth_ << "; mm++)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(A_trans_=='N')
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(p_.ls0*p_.vwidth) + "+ kk*" + to_string(llda)) << ";" << std::endl;
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(ls0_*vwidth_) + "+ kk*" + to_string(llda)) << ";" << std::endl;
else
{
if(p_.vwidth==1)
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
if(vwidth_==1)
stream << "rA[kk][mm] = ldsA[k + mm*" << ls0_*llda << "+ kk" << "];" << std::endl;
else
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << vwidth_*ls0_ << " + " << s << ")*" << llda << "+ kk];" << std::endl;
}
stream.dec_tab();
@@ -406,19 +406,19 @@ namespace templates
stream << "//Fetch B to registers" << std::endl;
stream << "#pragma unroll " << ks << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
stream << "for(uint32_t nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
stream << "#pragma unroll " << nS_/vwidth_ << std::endl;
stream << "for(uint32_t nn = 0; nn < " << nS_/vwidth_ << "; nn++)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(B_trans_=='T')
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(p_.ls1*p_.vwidth) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(ls1_*vwidth_) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
else
{
if(p_.vwidth==1)
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
if(vwidth_==1)
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << ls1_*lldb << "+ kk" << "];" << std::endl;
else
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << vwidth_*ls1_ << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
}
stream.dec_tab();
stream << "}" << std::endl;
@@ -427,41 +427,41 @@ namespace templates
stream << "#pragma unroll" << std::endl;
stream << "for(uint32_t kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
stream.inc_tab();
for(uint32_t nn=0; nn < p_.nS; ++nn)
for(uint32_t mm=0; mm < p_.mS; ++mm){
for(uint32_t nn=0; nn < nS_; ++nn)
for(uint32_t mm=0; mm < mS_; ++mm){
string res_str, lhs_str, rhs_str;
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
if (p_.vwidth==1)
if (vwidth_==1)
lhs_str = "rA[kk][" + to_string(mm) + "]";
else
lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
if (p_.vwidth==1)
lhs_str = access_vector_type("rA[kk][" + to_string(mm/vwidth_) + "]", mm%vwidth_);
if (vwidth_==1)
rhs_str = "rB[kk]["+to_string(nn)+"]";
else
rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
rhs_str = access_vector_type("rB[kk]["+to_string(nn/vwidth_)+"]", nn%vwidth_);
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
}
stream.dec_tab();
stream << "}" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
stream << "K -= " << p_.kL << ";" << std::endl;
stream << "K -= " << kL_ << ";" << std::endl;
//Increment A pointers to global memory
if (A_trans_=='N')
for(uint32_t i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
stream << "Ai[" << i << "] += " << kL_ << "*lda;" << std::endl;
else
for(uint32_t i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
stream << "Ai[" << i << "] += " << kL_ << ASTRIDE1 << ";" << std::endl;
//Increment B pointers to global memory
if (B_trans_=='T')
for(uint32_t i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
stream << "Bi[" << i << "] += " << kL_ << "*ldb;" << std::endl;
else
for(uint32_t i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
stream << "Bi[" << i << "] += " << kL_ << BSTRIDE1 << ";" << std::endl;
};
fetch_to_lds(false);
stream.dec_tab();
@@ -471,15 +471,15 @@ namespace templates
if(A_trans_=='N' || B_trans_=='T')
{
stream << "int Ky = K - idT.y;" << std::endl;
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t k = 0; k < kL_; k += lf1_)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N')
{
stream << "int Kx = K - idT.x;" << std::endl;
for(uint32_t k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t k = 0 ; k < kL_ ; k += lf0_*vwidth_)
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
}
fetch_to_lds(true);
@@ -498,35 +498,35 @@ namespace templates
stream << "N += ids.y;" << std::endl;
stream << "C += ids.x" << CSTRIDE1 << ";" << std::endl;
stream << "C += ids.z*" << p_.vwidth << CSTRIDE1 << ";" << std::endl;
stream << "C += ids.z*" << vwidth_ << CSTRIDE1 << ";" << std::endl;
stream << "C += ids.y*ldc;" << std::endl;
stream << "C += ids.w*" << p_.vwidth << "*ldc;" << std::endl;
stream << "C += ids.w*" << vwidth_ << "*ldc;" << std::endl;
if(has_depth)
stream << "C += gidz*ldc*N;" << std::endl;
stream << "M -= ids.x;" << std::endl;
stream << "M -= ids.z*" << p_.vwidth << ";" << std::endl;
stream << "M -= ids.z*" << vwidth_ << ";" << std::endl;
stream << "N -= ids.y;" << std::endl;
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
stream << "N -= ids.w*" << vwidth_ << ";" << std::endl;
for(uint32_t n=0; n < p_.nS; ++n)
for(uint32_t n=0; n < nS_; ++n)
{
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
string Cj = to_string((n/vwidth_)*(ls1_*vwidth_) + n%vwidth_);
stream << "if(" << Cj << " >= N) return;" << std::endl;
for(uint32_t m=0; m < p_.mS; ++m)
for(uint32_t m=0; m < mS_; ++m)
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
for(uint32_t m=0; m < p_.mS; ++m)
for(uint32_t m=0; m < mS_; ++m)
{
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
string Ci = to_string((m/vwidth_)*(ls0_*vwidth_) + m%vwidth_);
stream << "if(" << Ci << "< M) ";
if(has_depth)
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
else
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + ((beta != (" << sdtype << ")0)?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
}
if((n+1)%p_.vwidth==0){
stream << "C += ldc*" << p_.ls1*p_.vwidth - p_.vwidth + 1 << ";" << std::endl;
if((n+1)%vwidth_==0){
stream << "C += ldc*" << ls1_*vwidth_ - vwidth_ + 1 << ";" << std::endl;
}
else{
stream << "C += ldc;" << std::endl;
@@ -594,8 +594,8 @@ namespace templates
reduce_name += suffix;
driver::Kernel gemm(program, gemm_name.c_str());
driver::NDRange local(p_.ls0, p_.ls1, 1);
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
driver::NDRange local(ls0_, ls1_, 1);
driver::NDRange global(align(align(M,mS_)/mS_, ls0_), align(align(N,nS_)/nS_, ls1_), depth_);
uint32_t current_arg = 0;
@@ -603,7 +603,7 @@ namespace templates
gemm.setSizeArg(current_arg++, M);
gemm.setSizeArg(current_arg++, N);
gemm.setSizeArg(current_arg++, K);
if(p_.depth==1)
if(depth_==1)
{
if(backend==driver::OPENCL)
gemm.setArg(current_arg++, C.array.handle.cl);
@@ -642,15 +642,15 @@ namespace templates
gemm.setArg(current_arg++, beta);
options.enqueue(program.context(), gemm, global, local);
if(p_.depth > 1)
if(depth_ > 1)
{
uint32_t current_arg = 0;
driver::Kernel reduce(program, reduce_name.c_str());
driver::NDRange local(p_.ls0, p_.ls1);
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
driver::NDRange local(ls0_, ls1_);
driver::NDRange global(align(M, ls0_), align(N, ls1_));
reduce.setSizeArg(current_arg++, M);
reduce.setSizeArg(current_arg++, N);
reduce.setSizeArg(current_arg++, p_.depth);
reduce.setSizeArg(current_arg++, depth_);
reduce.setArg(current_arg++, workspace);
reduce.setSizeArg(current_arg++, M);
if(backend==driver::OPENCL)
@@ -682,7 +682,7 @@ namespace templates
,int_t ms, int_t ks, int_t ns
,fetch_type Afetch , fetch_type Bfetch
,int_t lf0, int_t lf1, char A_trans, char B_trans) :
base_impl(vwidth, ls0, ls1), kL_(kL), depth_(D), mS_(ms), kS_(ks), nS_(ns),
base_impl(vwidth, ls0, ls1), mL_(ms*ls0), kL_(kL), nL_(ns*ls1), depth_(D), mS_(ms), kS_(ks), nS_(ns),
Afetch_(Afetch), Bfetch_(Bfetch), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans)
{
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN;

View File

@@ -38,20 +38,20 @@ namespace templates
uint32_t reduce_1d::lmem_usage(expression_tree const & x) const
{
return p_.ls0*size_of(x.dtype());
return ls0_*size_of(x.dtype());
}
int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetch==FETCH_FROM_LOCAL)
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
uint32_t reduce_1d::temporary_workspace(expression_tree const &) const
{
if(p_.ng > 1)
return p_.ng;
if(ng_ > 1)
return ng_;
return 0;
}
@@ -99,13 +99,13 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
if (is_indexing(rd->op().type))
{
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint *)(tmp + " + tools::to_string(offset) + ");");
offset += 4*p_.ng;
offset += 4*ng_;
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.ng;
offset += size_of(dtype)*ng_;
}
else{
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.ng;
offset += size_of(dtype)*ng_;
}
}
};
@@ -118,7 +118,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << ",1,1)))" << std::endl; break;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << ",1,1)))" << std::endl; break;
}
stream << "$KERNEL void prod" << suffix << "($SIZE_T N, $GLOBAL char* tmp," << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
stream << "{" << std::endl;
@@ -135,18 +135,18 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
{
if(is_indexing(rd->op().type))
{
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
}
else
{
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
}
}
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
{
std::string dtype = append_width("#scalartype",vwidth);
//Fetch vector entry
@@ -174,7 +174,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << rd->process("#name_buf[lid] = #name_acc;") << std::endl;
}
//Reduce local memory
reduce_1d_local_memory(stream, p_.ls0, reductions, "#name_buf", "#name_buf_value", backend);
reduce_1d_local_memory(stream, ls0_, reductions, "#name_buf", "#name_buf_value", backend);
//Write to temporary buffers
stream << "if (lid==0)" << std::endl;
stream << "{" << std::endl;
@@ -205,19 +205,19 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
{
if (is_indexing(rd->op().type))
{
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];");
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];");
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
}
else
{
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
}
}
//Private reduction
stream << "for(uint32_t i = lid; i < " << p_.ng << "; i += lsize)" << std::endl;
stream << "for(uint32_t i = lid; i < " << ng_ << "; i += lsize)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
for (symbolic::reduce_1d* rd: reductions)
@@ -234,7 +234,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << rd->process("#name_buf[lid] = #name_acc;") << std::endl;
}
//Local reduction
reduce_1d_local_memory(stream, p_.ls0, reductions, "#name_buf", "#name_buf_value", backend);
reduce_1d_local_memory(stream, ls0_, reductions, "#name_buf", "#name_buf_value", backend);
//Write
stream << "if (lid==0)" << std::endl;
stream << "{" << std::endl;
@@ -250,7 +250,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
}
// Constructs a reduce_1d template.
//   vwidth : forwarded to base_impl as its first argument.
//   ls     : forwarded to base_impl as its second argument.
//   ng     : stored in ng_.
//   fetch  : stored in fetch_.
// NOTE(review): two initializer lines appear because this is a diff rendering
// with its +/- markers stripped. The first (pre-change) calls base_impl(vwidth,ls);
// the second (post-change) passes an explicit 1 as a third argument, presumably
// the second local size — confirm against base_impl's constructor.
reduce_1d::reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
{}
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const
@@ -275,8 +275,8 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
driver::Kernel kernels[2] = { driver::Kernel(program,name[0].c_str()), driver::Kernel(program,name[1].c_str()) };
//NDRange
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng), driver::NDRange(p_.ls0) };
driver::NDRange local[2] = { driver::NDRange(p_.ls0), driver::NDRange(p_.ls0) };
driver::NDRange global[2] = { driver::NDRange(ls0_*ng_), driver::NDRange(ls0_) };
driver::NDRange local[2] = { driver::NDRange(ls0_), driver::NDRange(ls0_) };
//Arguments
for (auto & kernel : kernels)
{

View File

@@ -41,22 +41,22 @@ namespace templates
// Validity check for the reduce_2d template. Returns
// TEMPLATE_INVALID_FETCHING_POLICY_TYPE when the fetch policy is
// FETCH_FROM_LOCAL, and TEMPLATE_VALID otherwise. Neither the device nor the
// expression tree argument is used.
// NOTE(review): the two consecutive `if` lines are the pre-change (p_.fetch) and
// post-change (fetch_) sides of a diff whose +/- markers were stripped. Read
// literally, the second `if` is the body of the first.
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
{
if (p_.fetch==FETCH_FROM_LOCAL)
if (fetch_==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
// Local-memory usage reported for the reduce_2d template: (ls0 + 1) * ls1. The
// expression tree argument is unused.
// NOTE(review): the "+1" looks like one padding element per row — confirm.
// NOTE(review): unlike reduce_1d::lmem_usage, no size_of(dtype) factor is
// applied, so the result is presumably an element count rather than bytes —
// confirm the unit callers expect.
// NOTE(review): the p_.ls0/p_.ls1 line is the pre-change side and the ls0_/ls1_
// line the post-change side of a diff whose +/- markers were stripped. Read
// literally, only the first return is reachable.
uint32_t reduce_2d::lmem_usage(const expression_tree&) const
{
return (p_.ls0+1)*p_.ls1;
return (ls0_+1)*ls1_;
}
// Temporary workspace requested by the reduce_2d template. M is taken from the
// first entry of input_sizes(expressions). The result is M * ng0 when more than
// one group is used along dimension 0, and 0 otherwise.
// NOTE(review): the product of int_t M and the group count is returned as
// uint32_t; whether it can exceed that range is not visible here.
// NOTE(review): the p_.ng0 lines are the pre-change side and the ng0_ lines the
// post-change side of a diff whose +/- markers were stripped.
uint32_t reduce_2d::temporary_workspace(expression_tree const & expressions) const
{
std::vector<int_t> MN = input_sizes(expressions);
int_t M = MN[0];
if(p_.ng0 > 1)
return M*p_.ng0;
if(ng0_ > 1)
return M*ng0_;
return 0;
}
@@ -74,7 +74,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
name[0] += suffix;
name[1] += suffix;
uint32_t ldls = p_.ls0;
uint32_t ldls = ls0_;
std::string ls0ldstr = to_string(ldls);
auto unroll_tmp = [&]()
@@ -87,13 +87,13 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
if (is_indexing(rd->op().type))
{
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint*)(tmp + " + tools::to_string(offset) + "*M);");
offset += 4*p_.ng0;
offset += 4*ng0_;
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.ng0;
offset += size_of(dtype)*ng0_;
}
else{
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.ng0;
offset += size_of(dtype)*ng0_;
}
}
};
@@ -107,7 +107,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "#include \"vector.h\"" << std::endl;
break;
case driver::OPENCL:
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
break;
}
stream << "$KERNEL void " << name[0] << "($SIZE_T M, $SIZE_T N, $GLOBAL char* tmp, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
@@ -119,13 +119,13 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "$SIZE_T lidy = $LOCAL_IDX_1;" << std::endl;
//Loop r
std::ostringstream upper;
upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1;
upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_;
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
{
//Declare Buffers
for (symbolic::reduce_2d* rd : reductions)
stream << rd->process("$LOCAL " + append_width("#scalartype", cwidth) + " #name_buf[" + to_string(p_.ls1*ldls) + "];") << std::endl;
stream << rd->process("$LOCAL " + append_width("#scalartype", cwidth) + " #name_buf[" + to_string(ls1_*ldls) + "];") << std::endl;
//Accumulators
for (symbolic::reduce_2d* rd : reductions){
@@ -136,7 +136,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "if (r < M)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
{
std::string rdtype = append_width("#scalartype", rwidth);
std::string cdtype = append_width("#scalartype", cwidth);
@@ -167,7 +167,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << rd->process("#name_buf[lidy*" + ls0ldstr + "+ lidx] = #name_acc;") << std::endl;
//Reduce local memory
stream << "#pragma unroll" << std::endl;
stream << "for($SIZE_T stride = " << p_.ls0/2 << "; stride >0; stride /=2)" << std::endl;
stream << "for($SIZE_T stride = " << ls0_/2 << "; stride >0; stride /=2)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << "$LOCAL_BARRIER;" << std::endl;
@@ -189,7 +189,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "if (r < M && lidx == 0)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(p_.ng0==1)
if(ng0_==1)
for(size_t idx: assignments)
for(size_t s = 0 ; s < cwidth ; ++s)
stream << symbols.at(idx)->evaluate({{"leaf", "at(r+" + to_string(s) + ")"},
@@ -211,17 +211,17 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
/* ------------------------
* Kernel 2
* -----------------------*/
if(p_.ng0>1)
if(ng0_>1)
{
if(backend==driver::OPENCL)
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
stream << "$KERNEL void " << name[1] << "($SIZE_T M, $SIZE_T N , $GLOBAL char* tmp, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
for (symbolic::reduce_2d* rd : reductions)
stream << rd->process("$LOCAL #scalartype #name_buf[" + to_string(p_.ls1*ldls) + "];") << std::endl;
stream << "for($SIZE_T r = $GLOBAL_IDX_1; r < (M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1 << "; r += " << GlobalSize1(backend) << "){" << std::endl;
stream << rd->process("$LOCAL #scalartype #name_buf[" + to_string(ls1_*ldls) + "];") << std::endl;
stream << "for($SIZE_T r = $GLOBAL_IDX_1; r < (M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_ << "; r += " << GlobalSize1(backend) << "){" << std::endl;
stream.inc_tab();
stream << "$SIZE_T lidx = $LOCAL_IDX_0;" << std::endl;
stream << "$SIZE_T lidy = $LOCAL_IDX_1;" << std::endl;
@@ -230,7 +230,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "if (r < M)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << "for($SIZE_T c = lidx; c < " << p_.ng0 << "; c += $LOCAL_SIZE_0){" << std::endl;
stream << "for($SIZE_T c = lidx; c < " << ng0_ << "; c += $LOCAL_SIZE_0){" << std::endl;
stream.inc_tab();
for (symbolic::reduce_2d* rd: reductions)
compute_reduce_1d(stream, rd->process("#name_acc"), rd->process("#name_temp[r + M*c]"), rd->op());
@@ -241,7 +241,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
for (symbolic::reduce_2d* rd : reductions)
stream << rd->process("#name_buf[lidy*" + ls0ldstr + "+ lidx] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for($SIZE_T stride = " << p_.ls0/2 << "; stride >0; stride /=2)" << std::endl;
stream << "for($SIZE_T stride = " << ls0_/2 << "; stride >0; stride /=2)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << "$LOCAL_BARRIER;" << std::endl;
@@ -300,7 +300,7 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
name[0] += suffix;
name[1] += suffix;
uint32_t nk = (p_.ng0==1)?1:2;
uint32_t nk = (ng0_==1)?1:2;
std::vector<driver::Kernel> kernels;
for(uint32_t k = 0 ; k < nk ; ++k)
@@ -319,8 +319,8 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
}
//NDRange
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng0, p_.ls1*p_.ng1), driver::NDRange(p_.ls0, p_.ls1*p_.ng1) };
driver::NDRange local[2] = { driver::NDRange(p_.ls0, p_.ls1), driver::NDRange(p_.ls0, p_.ls1) };
driver::NDRange global[2] = { driver::NDRange(ls0_*ng0_, ls1_*ng1_), driver::NDRange(ls0_, ls1_*ng1_) };
driver::NDRange local[2] = { driver::NDRange(ls0_, ls1_), driver::NDRange(ls0_, ls1_) };
for(uint32_t i = 0 ; i < nk ; ++i)
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
}

View File

@@ -23,6 +23,8 @@
#include "common.hpp"
#include "core.h"
namespace tpt = sc::templates;
namespace detail
{
@@ -103,7 +105,8 @@ namespace detail
{
// Builds a shared rt::profiles::value_type from a Python-side template object
// (tp), a Python dtype object and a command queue. The template type and dtype
// are obtained through tools::extract_template_type / tools::extract_dtype.
// NOTE(review): this is a diff rendering with its +/- markers stripped. The
// first return is the pre-change code, which extracts the template as a base
// reference. The two lines after it are the post-change code, which extracts a
// raw tpt::base* and hands raw->getptr() to the value_type constructor. Read
// literally, only the first return is reachable.
// NOTE(review): getptr() is presumably a shared_from_this() wrapper, in which
// case the object behind `tp` must already be owned by a std::shared_ptr —
// confirm the Boost.Python class registrations use a shared_ptr holder.
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
{
return std::shared_ptr<rt::profiles::value_type>(new rt::profiles::value_type(tools::extract_template_type(tp), tools::extract_dtype(dtype), (sc::templates::base const &)bp::extract<sc::templates::base>(tp), queue));
tpt::base* raw = bp::extract<tpt::base*>(tp);
return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
}
std::shared_ptr<sc::array>

View File

@@ -58,8 +58,8 @@ void export_templates()
//Base
{
#define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
#define __PROP(name) .def_readonly(#name, &tpt::base::name)
bp::class_<tpt::base, std::shared_ptr<tpt::base>, boost::noncopyable>("base", bp::no_init)
.def("lmem_usage", &tpt::base::lmem_usage)
.def("registers_usage", &tpt::base::registers_usage)
.def("is_invalid", &tpt::base::is_invalid)
@@ -68,26 +68,23 @@ void export_templates()
#undef __PROP
}
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init)\
.add_property("ls0", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls0)\
.add_property("ls1", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls1);
bp::class_<tpt::base_impl, bp::bases<tpt::base>, boost::noncopyable>("base_impl", bp::no_init)
.add_property("ls0", &tpt::base_impl::ls0)
.add_property("ls1", &tpt::base_impl::ls1);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
#define WRAP_BASE(name) bp::class_<tpt::name, bp::bases<tpt::base_impl>, boost::noncopyable>(#name, bp::no_init);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, std::shared_ptr<tpt::name>, bp::bases<basename>>(#name, bp::init<__VA_ARGS__>())\
;
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
//Vector AXPY
WRAP_SINGLE_TEMPLATE(elementwise_1d, uint, uint, uint, tpt::fetch_type)
WRAP_SINGLE_TEMPLATE(elementwise_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_SINGLE_TEMPLATE(reduce_1d, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(elementwise_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(elementwise_2d, tpt::base_impl, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(reduce_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
WRAP_BASE(reduce_2d)
WRAP_TEMPLATE(reduce_2d_rows, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(reduce_2d_cols, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(reduce_2d_rows, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(reduce_2d_cols, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_BASE(gemm)
WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_nn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_tn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_nt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(gemm_tt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
}