Python: Fixed wrapper issues induced after cleaning
This commit is contained in:
@@ -71,7 +71,7 @@ static const int TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE = -18;
|
||||
static const int TEMPLATE_TEMPORARY_TOO_LARGE = -19;
|
||||
static const int TEMPLATE_BLOCK_SIZE_TOO_LARGE = -20;
|
||||
|
||||
class base
|
||||
class base: public std::enable_shared_from_this<base>
|
||||
{
|
||||
private:
|
||||
virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0;
|
||||
@@ -85,6 +85,9 @@ public:
|
||||
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
|
||||
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
|
||||
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
|
||||
// Returns a std::shared_ptr<base> that shares ownership of this object, obtained
// through std::enable_shared_from_this<base>::shared_from_this().
// Precondition (standard-library contract): the object must already be owned by a
// std::shared_ptr. Otherwise shared_from_this() throws std::bad_weak_ptr since
// C++17 and is undefined behaviour before that.
std::shared_ptr<base> getptr() {
|
||||
return shared_from_this();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
@@ -36,8 +36,7 @@ private:
|
||||
int is_invalid_impl(driver::Device const &, expression_tree const &) const;
|
||||
std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const;
|
||||
public:
|
||||
elementwise_2d(parameters_type const & parameters);
|
||||
elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch);
|
||||
elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch);
|
||||
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
|
||||
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
|
||||
private:
|
||||
|
@@ -32,7 +32,7 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
class reduce_2d : public base_impl<reduce_2d, reduce_2d_parameters>
|
||||
class reduce_2d : public base_impl
|
||||
{
|
||||
protected:
|
||||
reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, operation_type_family);
|
||||
|
@@ -73,12 +73,11 @@ base_impl::base_impl(uint32_t vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls
|
||||
{ }
|
||||
|
||||
// Read-only accessor for the first local-size value held by base_impl.
// NOTE(review): two bodies follow, one reading p_.ls0 and one reading ls0_.
// Presumably these are the pre-change and post-change lines of a diff; only one
// can exist in the real file. Confirm against the checked-in source.
uint32_t base_impl::ls0() const
|
||||
{ return p_.ls0; }
|
||||
{ return ls0_; }
|
||||
|
||||
// Read-only accessor for the second local-size value held by base_impl.
// NOTE(review): two bodies follow, one reading p_.ls1 and one reading ls1_.
// Presumably these are the pre-change and post-change lines of a diff; only one
// can exist in the real file. Confirm against the checked-in source.
uint32_t base_impl::ls1() const
|
||||
{ return p_.ls1; }
|
||||
{ return ls1_; }
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const
|
||||
{
|
||||
//Query device informations
|
||||
@@ -90,16 +89,16 @@ int base_impl::is_invalid(expression_tree const & expressions, driver::Device c
|
||||
//Invalid work group size
|
||||
size_t max_workgroup_size = device.max_work_group_size();
|
||||
std::vector<size_t> max_work_item_sizes = device.max_work_item_sizes();
|
||||
if (p_.ls0*p_.ls1 > max_workgroup_size)
|
||||
if (ls0_*ls1_ > max_workgroup_size)
|
||||
return TEMPLATE_WORK_GROUP_SIZE_OVERFLOW;
|
||||
if (p_.ls0 > max_work_item_sizes[0])
|
||||
if (ls0_ > max_work_item_sizes[0])
|
||||
return TEMPLATE_LOCAL_SIZE_0_OVERFLOW;
|
||||
|
||||
if (p_.ls1 > max_work_item_sizes[1])
|
||||
if (ls1_ > max_work_item_sizes[1])
|
||||
return TEMPLATE_LOCAL_SIZE_1_OVERFLOW;
|
||||
|
||||
//Invalid SIMD Width
|
||||
if (p_.vwidth!=1 && p_.vwidth!=2 && p_.vwidth!=3 && p_.vwidth!=4)
|
||||
if (vwidth_!=1 && vwidth_!=2 && vwidth_!=3 && vwidth_!=4)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
|
||||
return is_invalid_impl(device, expressions);
|
||||
|
@@ -38,7 +38,7 @@ namespace templates
|
||||
|
||||
// Validity check specific to the elementwise_1d template.
// Returns TEMPLATE_INVALID_FETCHING_POLICY_TYPE when the fetch policy is
// FETCH_FROM_LOCAL, and TEMPLATE_VALID otherwise.
// Both parameters (device, expression tree) are unnamed and not consulted.
int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
// NOTE(review): the next two `if` lines test the same condition through p_.fetch
// and fetch_ respectively. Presumably the first is the pre-change diff line and
// the second its replacement; confirm which one the real file contains.
if (p_.fetch==FETCH_FROM_LOCAL)
|
||||
if (fetch_==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
@@ -57,7 +57,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
||||
case driver::CUDA:
|
||||
stream << "#include \"vector.h\"" << std::endl; break;
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl; break;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
|
||||
}
|
||||
|
||||
stream << "$KERNEL void elementwise_1d" << suffix << "($SIZE_T N, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")";
|
||||
@@ -75,7 +75,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
||||
stream.inc_tab();
|
||||
}
|
||||
|
||||
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||
{
|
||||
std::string dtype = append_width("#scalartype",vwidth);
|
||||
|
||||
@@ -111,7 +111,7 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
|
||||
}
|
||||
|
||||
// Constructs an elementwise_1d template.
//   vwidth - forwarded to base_impl as its first argument
//   ls     - forwarded to base_impl as its second argument
//   ng     - stored in ng_
//   fetch  - stored in fetch_
// The body is empty; all work happens in the member-initialiser list.
elementwise_1d::elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||
// NOTE(review): two initialiser lines follow. The first calls base_impl with two
// arguments, the second with three, passing a literal 1 as the extra argument.
// Presumably these are the pre-change and post-change diff lines; confirm against
// the real file which base_impl constructor signature exists.
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
|
||||
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
|
||||
{}
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@ void elementwise_1d::enqueue(driver::CommandQueue &, driver::Program const & pro
|
||||
name += suffix;
|
||||
driver::Kernel kernel(program, name.c_str());
|
||||
//NDRange
|
||||
driver::NDRange global(p_.ls0*p_.ng);
|
||||
driver::NDRange local(p_.ls0);
|
||||
driver::NDRange global(ls0_*ng_);
|
||||
driver::NDRange local(ls0_);
|
||||
//Arguments
|
||||
uint32_t current_arg = 0;
|
||||
kernel.setSizeArg(current_arg++, size);
|
||||
|
@@ -35,9 +35,9 @@ namespace templates
|
||||
|
||||
// Validity check specific to the elementwise_2d template.
// Returns, in order of precedence:
//   TEMPLATE_INVALID_SIMD_WIDTH           - when the vector width is greater than 1
//   TEMPLATE_INVALID_FETCHING_POLICY_TYPE - when the fetch policy is FETCH_FROM_LOCAL
//   TEMPLATE_VALID                        - otherwise
// Both parameters (device, expression tree) are unnamed and not consulted.
// NOTE(review): each condition appears twice below, once via p_.<field> and once
// via the <field>_ member. Presumably these are pre-change/post-change diff lines;
// confirm which form the real file contains.
int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.vwidth>1)
|
||||
if (vwidth_>1)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
if(p_.fetch==FETCH_FROM_LOCAL)
|
||||
if(fetch_==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
@@ -60,7 +60,7 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
|
||||
case driver::CUDA:
|
||||
stream << "#include \"vector.h\"" << std::endl; break;
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl; break;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
|
||||
}
|
||||
|
||||
stream << "$KERNEL void elementwise_2d" << suffix << "($SIZE_T M, $SIZE_T N, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
|
||||
@@ -68,11 +68,11 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
|
||||
stream.inc_tab();
|
||||
|
||||
|
||||
fetching_loop_info(p_.fetch, "M", stream, init0, upper_bound0, inc0, "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device);
|
||||
fetching_loop_info(fetch_, "M", stream, init0, upper_bound0, inc0, "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device);
|
||||
stream << "for($SIZE_T i = " << init0 << "; i < " << upper_bound0 << "; i += " << inc0 << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
fetching_loop_info(p_.fetch, "N", stream, init1, upper_bound1, inc1, "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device);
|
||||
fetching_loop_info(fetch_, "N", stream, init1, upper_bound1, inc1, "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device);
|
||||
stream << "for($SIZE_T j = " << init1 << "; j < " << upper_bound1 << "; j += " << inc1 << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
@@ -119,8 +119,8 @@ void elementwise_2d::enqueue(driver::CommandQueue & /*queue*/, driver::Program c
|
||||
std::string name = "elementwise_2d";
|
||||
name +=suffix;
|
||||
driver::Kernel kernel(program, name.c_str());
|
||||
driver::NDRange global(p_.ls0*p_.ng0, p_.ls1*p_.ng1);
|
||||
driver::NDRange local(p_.ls0, p_.ls1);
|
||||
driver::NDRange global(ls0_*ng0_, ls1_*ng1_);
|
||||
driver::NDRange local(ls0_, ls1_);
|
||||
uint32_t current_arg = 0;
|
||||
std::vector<int_t> MN = input_sizes(expressions);
|
||||
kernel.setSizeArg(current_arg++, MN[0]);
|
||||
|
@@ -40,10 +40,10 @@ namespace templates
|
||||
uint32_t gemm::lmem_usage(expression_tree const & expression) const
|
||||
{
|
||||
uint32_t N = 0;
|
||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
||||
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
||||
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
size_t llda = (A_trans_=='N')?mL_:kL_+1;
|
||||
size_t lnda = (A_trans_=='N')?kL_:mL_;
|
||||
size_t lldb = (B_trans_=='T')?nL_:kL_+1;
|
||||
size_t lndb = (B_trans_=='T')?kL_:nL_;
|
||||
N += llda*lnda;
|
||||
N += lldb*lndb;
|
||||
return N*size_of(expression.dtype());
|
||||
@@ -51,7 +51,7 @@ namespace templates
|
||||
|
||||
// Returns an element count, mS*nS + mS*kS + kS*nS, multiplied by
// size_of(expression.dtype()).
// The three products correspond to an mS x nS block, an mS x kS block and a
// kS x nS block. Presumably these are the per-work-item C, A and B register
// tiles and the result is in bytes -- TODO confirm against size_of() and the
// kernel generator.
uint32_t gemm::registers_usage(expression_tree const & expression) const
|
||||
{
|
||||
// NOTE(review): N is declared twice below with the same formula, once via
// p_.<field> and once via the <field>_ members. Presumably these are the
// pre-change and post-change diff lines; only one can exist in the real file.
uint32_t N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
|
||||
uint32_t N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;
|
||||
return N*size_of(expression.dtype());
|
||||
}
|
||||
|
||||
@@ -59,51 +59,51 @@ namespace templates
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
if(p_.depth > 1)
|
||||
return M*N*p_.depth;
|
||||
if(depth_ > 1)
|
||||
return M*N*depth_;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if(p_.Afetch!=FETCH_FROM_LOCAL || p_.Bfetch!=FETCH_FROM_LOCAL)
|
||||
if(Afetch_!=FETCH_FROM_LOCAL || Bfetch_!=FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
|
||||
if ((p_.mS % p_.vwidth) > 0 || (p_.nS % p_.vwidth) > 0)
|
||||
if ((mS_ % vwidth_) > 0 || (nS_ % vwidth_) > 0)
|
||||
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
|
||||
|
||||
if(p_.mL > 256 || p_.nL > 256)
|
||||
if(mL_ > 256 || nL_ > 256)
|
||||
return TEMPLATE_BLOCK_SIZE_TOO_LARGE;
|
||||
|
||||
if ( p_.kS % p_.kL == 0)
|
||||
if ( kS_ % kL_ == 0)
|
||||
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
|
||||
|
||||
if (p_.Afetch==FETCH_FROM_LOCAL || p_.Bfetch==FETCH_FROM_LOCAL){
|
||||
if ((p_.lf0*p_.lf1) !=(p_.ls0*p_.ls1))
|
||||
if (Afetch_==FETCH_FROM_LOCAL || Bfetch_==FETCH_FROM_LOCAL){
|
||||
if ((lf0_*lf1_) !=(ls0_*ls1_))
|
||||
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
|
||||
}
|
||||
|
||||
if (p_.Afetch==FETCH_FROM_LOCAL)
|
||||
if (Afetch_==FETCH_FROM_LOCAL)
|
||||
{
|
||||
uint32_t bound1 = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
uint32_t bound0 = (A_trans_=='N')?p_.mL:p_.kL;
|
||||
uint32_t bound1 = (A_trans_=='N')?kL_:mL_;
|
||||
uint32_t bound0 = (A_trans_=='N')?mL_:kL_;
|
||||
|
||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||
if (lf1_>0 && (bound1 % lf1_)> 0)
|
||||
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
||||
if (lf0_>0 && (bound0 % (lf0_*vwidth_)) > 0)
|
||||
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
|
||||
|
||||
}
|
||||
if (p_.Bfetch==FETCH_FROM_LOCAL)
|
||||
if (Bfetch_==FETCH_FROM_LOCAL)
|
||||
{
|
||||
uint32_t bound1 = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
uint32_t bound0 = (B_trans_=='T')?p_.nL:p_.kL;
|
||||
uint32_t bound1 = (B_trans_=='T')?kL_:nL_;
|
||||
uint32_t bound0 = (B_trans_=='T')?nL_:kL_;
|
||||
|
||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||
if (lf1_>0 && (bound1 % lf1_)> 0)
|
||||
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
||||
if (lf0_>0 && (bound0 % (lf0_*vwidth_)) > 0)
|
||||
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
}
|
||||
@@ -117,10 +117,10 @@ namespace templates
|
||||
using tools::to_string;
|
||||
|
||||
driver::backend_type backend = device.backend();
|
||||
bool has_depth = p_.depth > 1;
|
||||
#define VLOAD(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, true)
|
||||
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, false)
|
||||
#define VSTORE(value, offset, ptr) vstore(p_.vwidth, sdtype, value, offset, ptr, "1", backend)
|
||||
bool has_depth = depth_ > 1;
|
||||
#define VLOAD(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, true)
|
||||
#define VLOAD_MISALIGNED(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, false)
|
||||
#define VSTORE(value, offset, ptr) vstore(vwidth_, sdtype, value, offset, ptr, "1", backend)
|
||||
|
||||
symbolic::preset::gemm::args args;
|
||||
infos(tree, args);
|
||||
@@ -134,7 +134,7 @@ namespace templates
|
||||
kernel_generation_stream stream(backend);
|
||||
numeric_type dtype = tree.dtype();
|
||||
std::string sdtype = to_string(dtype);
|
||||
std::string vdtype = append_width(sdtype, p_.vwidth);
|
||||
std::string vdtype = append_width(sdtype, vwidth_);
|
||||
|
||||
//////////////////
|
||||
/// DECLARATIONS
|
||||
@@ -148,7 +148,7 @@ namespace templates
|
||||
switch(backend)
|
||||
{
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -166,20 +166,20 @@ namespace templates
|
||||
|
||||
///Declare
|
||||
stream << "//blocks" << std::endl;
|
||||
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{0}};" << std::endl;
|
||||
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.vwidth << "];" << std::endl;
|
||||
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.vwidth << "];" << std::endl;
|
||||
stream << sdtype << " rC[" << mS_ << "][" << nS_ << "] = {{0}};" << std::endl;
|
||||
stream << vdtype << " rA[" << kS_ << "][" << mS_/vwidth_ << "];" << std::endl;
|
||||
stream << vdtype << " rB[" << kS_ << "][" << nS_/vwidth_ << "];" << std::endl;
|
||||
stream << std::endl;
|
||||
|
||||
stream << "//pointers" << std::endl;
|
||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
||||
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
||||
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
size_t llda = (A_trans_=='N')?mL_:kL_+1;
|
||||
size_t lnda = (A_trans_=='N')?kL_:mL_;
|
||||
size_t lldb = (B_trans_=='T')?nL_:kL_+1;
|
||||
size_t lndb = (B_trans_=='T')?kL_:nL_;
|
||||
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
|
||||
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
|
||||
uint32_t npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
uint32_t npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
uint32_t npA = mL_/(A_trans_=='N'?lf0_*vwidth_:lf1_);
|
||||
uint32_t npB = nL_/(B_trans_=='T'?lf0_*vwidth_:lf1_);
|
||||
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
||||
stream << std::endl;
|
||||
@@ -204,20 +204,20 @@ namespace templates
|
||||
if(has_depth)
|
||||
{
|
||||
stream << "gidz = $GROUP_IDX_2;" << std::endl;
|
||||
stream << "div = (K+" << p_.depth-1 << ")/" << p_.depth << ";" << std::endl;
|
||||
stream << "div = (K+" << depth_-1 << ")/" << depth_ << ";" << std::endl;
|
||||
stream << "offz = div*gidz;" << std::endl;
|
||||
stream << "K = min(K - div*gidz, ($SIZE_T)div);" << std::endl;
|
||||
}
|
||||
|
||||
stream << "idt = " << p_.ls0 << "*ids.w + ids.z;" << std::endl;
|
||||
stream << "idT.y = idt/" << p_.lf0 << ";" << std::endl;
|
||||
stream << "idT.x = idt - " << p_.lf0 << "*idT.y;" << std::endl;
|
||||
stream << "idt = " << ls0_ << "*ids.w + ids.z;" << std::endl;
|
||||
stream << "idT.y = idt/" << lf0_ << ";" << std::endl;
|
||||
stream << "idT.x = idt - " << lf0_ << "*idT.y;" << std::endl;
|
||||
stream << std::endl;
|
||||
|
||||
stream << "//Adjust pointers and bounds per work-item" << std::endl;
|
||||
stream << "ids.x *= " << p_.mL << ";" << std::endl;
|
||||
stream << "ids.y *= " << p_.nL << ";" << std::endl;
|
||||
stream << "idT.x *= " << p_.vwidth << ";" << std::endl;
|
||||
stream << "ids.x *= " << mL_ << ";" << std::endl;
|
||||
stream << "ids.y *= " << nL_ << ";" << std::endl;
|
||||
stream << "idT.x *= " << vwidth_ << ";" << std::endl;
|
||||
|
||||
stream << "M -= ids.x;" << std::endl;
|
||||
if(A_trans_=='N')
|
||||
@@ -280,19 +280,19 @@ namespace templates
|
||||
|
||||
for(uint32_t i = 0 ; i < npA ; i++ )
|
||||
if (A_trans_=='N')
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < M", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
else
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < M", "(int)((idT.y + " + to_string(i*lf1_) + ")*lda)", "0") << ";" << std::endl;
|
||||
|
||||
for(uint32_t i = 0 ; i < npB ; i++ )
|
||||
if (B_trans_=='T')
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf0_*vwidth_) + " < N", "(int)((idT.x + " + to_string(i*lf0_*vwidth_) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
else
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < N", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*ldb)", "0") << ";" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*lf1_) + " < N", "(int)((idT.y + " + to_string(i*lf1_) + ")*ldb)", "0") << ";" << std::endl;
|
||||
|
||||
stream << std::endl;
|
||||
stream << "//Outer loop" << std::endl;
|
||||
stream << "while(K >=" << p_.kL << ")" << std::endl;
|
||||
stream << "while(K >=" << kL_ << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
@@ -306,13 +306,13 @@ namespace templates
|
||||
stream << "//Fetch A to local memory" << std::endl;
|
||||
if (A_trans_=='N')
|
||||
{
|
||||
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||
for(uint32_t m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
|
||||
for(uint32_t k = 0; k < kL_; k += lf1_)
|
||||
for(uint32_t m = 0; m < mL_; m += lf0_*vwidth_)
|
||||
{
|
||||
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
|
||||
std::string mm = to_string(m/(vwidth_*lf0_));
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
|
||||
else
|
||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
|
||||
@@ -320,13 +320,13 @@ namespace templates
|
||||
}
|
||||
else
|
||||
{
|
||||
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||
for(uint32_t m = 0; m < p_.mL; m += p_.lf1)
|
||||
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_)
|
||||
for(uint32_t m = 0; m < mL_; m += lf1_)
|
||||
{
|
||||
std::string mm = to_string(m/p_.lf1);
|
||||
std::string mm = to_string(m/lf1_);
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
|
||||
|
||||
else
|
||||
@@ -337,13 +337,13 @@ namespace templates
|
||||
stream << "//Fetch B to local memory" << std::endl;
|
||||
if (B_trans_=='T')
|
||||
{
|
||||
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||
for(uint32_t n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
|
||||
for(uint32_t k = 0; k < kL_; k += lf1_)
|
||||
for(uint32_t n = 0; n < nL_; n += lf0_*vwidth_)
|
||||
{
|
||||
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
|
||||
std::string nn = to_string(n/(vwidth_*lf0_));
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
|
||||
else
|
||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
||||
@@ -351,13 +351,13 @@ namespace templates
|
||||
}
|
||||
else
|
||||
{
|
||||
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||
for(uint32_t n = 0; n < p_.nL; n += p_.lf1)
|
||||
for(uint32_t k = 0; k < kL_; k += lf0_*vwidth_)
|
||||
for(uint32_t n = 0; n < nL_; n += lf1_)
|
||||
{
|
||||
std::string nn = to_string(n/p_.lf1);
|
||||
std::string nn = to_string(n/lf1_);
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
|
||||
|
||||
else
|
||||
@@ -366,18 +366,18 @@ namespace templates
|
||||
}
|
||||
|
||||
if(A_trans_=='N')
|
||||
stream << "ldsA = lA + ids.z*" << p_.vwidth << ";" << std::endl;
|
||||
stream << "ldsA = lA + ids.z*" << vwidth_ << ";" << std::endl;
|
||||
else
|
||||
stream << "ldsA = lA + ids.z*" << llda*p_.vwidth << ";" << std::endl;
|
||||
stream << "ldsA = lA + ids.z*" << llda*vwidth_ << ";" << std::endl;
|
||||
|
||||
if(B_trans_=='T')
|
||||
stream << "ldsB = lB + ids.w*" << p_.vwidth << ";" << std::endl;
|
||||
stream << "ldsB = lB + ids.w*" << vwidth_ << ";" << std::endl;
|
||||
else
|
||||
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
|
||||
stream << "ldsB = lB + ids.w*" << lldb*vwidth_ << ";" << std::endl;
|
||||
|
||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
|
||||
size_t ks = last_iteration?1:p_.kS;
|
||||
std::string bound = last_iteration?"K":tools::to_string(kL_);
|
||||
size_t ks = last_iteration?1:kS_;
|
||||
stream << "//Inner loop" << std::endl;
|
||||
stream << "for(uint32_t k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
|
||||
stream.inc_tab();
|
||||
@@ -385,19 +385,19 @@ namespace templates
|
||||
stream << "//Fetch A to registers" << std::endl;
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
||||
stream << "for(uint32_t mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
||||
stream << "#pragma unroll " << mS_/vwidth_ << std::endl;
|
||||
stream << "for(uint32_t mm = 0; mm < " << mS_/vwidth_ << "; mm++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(A_trans_=='N')
|
||||
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(p_.ls0*p_.vwidth) + "+ kk*" + to_string(llda)) << ";" << std::endl;
|
||||
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(ls0_*vwidth_) + "+ kk*" + to_string(llda)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.vwidth==1)
|
||||
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
|
||||
if(vwidth_==1)
|
||||
stream << "rA[kk][mm] = ldsA[k + mm*" << ls0_*llda << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << vwidth_*ls0_ << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
||||
}
|
||||
|
||||
stream.dec_tab();
|
||||
@@ -406,19 +406,19 @@ namespace templates
|
||||
stream << "//Fetch B to registers" << std::endl;
|
||||
stream << "#pragma unroll " << ks << std::endl;
|
||||
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
||||
stream << "for(uint32_t nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
||||
stream << "#pragma unroll " << nS_/vwidth_ << std::endl;
|
||||
stream << "for(uint32_t nn = 0; nn < " << nS_/vwidth_ << "; nn++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(B_trans_=='T')
|
||||
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(p_.ls1*p_.vwidth) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
|
||||
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(ls1_*vwidth_) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.vwidth==1)
|
||||
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
|
||||
if(vwidth_==1)
|
||||
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << ls1_*lldb << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << vwidth_*ls1_ << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
||||
}
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
@@ -427,41 +427,41 @@ namespace templates
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for(uint32_t kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
|
||||
stream.inc_tab();
|
||||
for(uint32_t nn=0; nn < p_.nS; ++nn)
|
||||
for(uint32_t mm=0; mm < p_.mS; ++mm){
|
||||
for(uint32_t nn=0; nn < nS_; ++nn)
|
||||
for(uint32_t mm=0; mm < mS_; ++mm){
|
||||
string res_str, lhs_str, rhs_str;
|
||||
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
||||
if (p_.vwidth==1)
|
||||
if (vwidth_==1)
|
||||
lhs_str = "rA[kk][" + to_string(mm) + "]";
|
||||
else
|
||||
lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
|
||||
if (p_.vwidth==1)
|
||||
lhs_str = access_vector_type("rA[kk][" + to_string(mm/vwidth_) + "]", mm%vwidth_);
|
||||
if (vwidth_==1)
|
||||
rhs_str = "rB[kk]["+to_string(nn)+"]";
|
||||
else
|
||||
rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
|
||||
rhs_str = access_vector_type("rB[kk]["+to_string(nn/vwidth_)+"]", nn%vwidth_);
|
||||
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||
}
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
stream << "K -= " << p_.kL << ";" << std::endl;
|
||||
stream << "K -= " << kL_ << ";" << std::endl;
|
||||
|
||||
//Increment A pointers to global memory
|
||||
if (A_trans_=='N')
|
||||
for(uint32_t i = 0 ; i < npA ; ++i)
|
||||
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << kL_ << "*lda;" << std::endl;
|
||||
else
|
||||
for(uint32_t i = 0 ; i < npA ; ++i)
|
||||
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << kL_ << ASTRIDE1 << ";" << std::endl;
|
||||
|
||||
//Increment B pointers to global memory
|
||||
if (B_trans_=='T')
|
||||
for(uint32_t i = 0 ; i < npB ; ++i)
|
||||
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << kL_ << "*ldb;" << std::endl;
|
||||
else
|
||||
for(uint32_t i = 0 ; i < npB ; ++i)
|
||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << kL_ << BSTRIDE1 << ";" << std::endl;
|
||||
};
|
||||
fetch_to_lds(false);
|
||||
stream.dec_tab();
|
||||
@@ -471,15 +471,15 @@ namespace templates
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
{
|
||||
stream << "int Ky = K - idT.y;" << std::endl;
|
||||
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
|
||||
for(uint32_t k = 0; k < kL_; k += lf1_)
|
||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||
}
|
||||
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
{
|
||||
stream << "int Kx = K - idT.x;" << std::endl;
|
||||
for(uint32_t k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
|
||||
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
|
||||
for(uint32_t k = 0 ; k < kL_ ; k += lf0_*vwidth_)
|
||||
for(uint32_t s = 0 ; s < vwidth_ ; ++s)
|
||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||
}
|
||||
fetch_to_lds(true);
|
||||
@@ -498,35 +498,35 @@ namespace templates
|
||||
stream << "N += ids.y;" << std::endl;
|
||||
|
||||
stream << "C += ids.x" << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.z*" << p_.vwidth << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.z*" << vwidth_ << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.y*ldc;" << std::endl;
|
||||
stream << "C += ids.w*" << p_.vwidth << "*ldc;" << std::endl;
|
||||
stream << "C += ids.w*" << vwidth_ << "*ldc;" << std::endl;
|
||||
if(has_depth)
|
||||
stream << "C += gidz*ldc*N;" << std::endl;
|
||||
|
||||
stream << "M -= ids.x;" << std::endl;
|
||||
stream << "M -= ids.z*" << p_.vwidth << ";" << std::endl;
|
||||
stream << "M -= ids.z*" << vwidth_ << ";" << std::endl;
|
||||
|
||||
stream << "N -= ids.y;" << std::endl;
|
||||
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
|
||||
stream << "N -= ids.w*" << vwidth_ << ";" << std::endl;
|
||||
|
||||
for(uint32_t n=0; n < p_.nS; ++n)
|
||||
for(uint32_t n=0; n < nS_; ++n)
|
||||
{
|
||||
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
|
||||
string Cj = to_string((n/vwidth_)*(ls1_*vwidth_) + n%vwidth_);
|
||||
stream << "if(" << Cj << " >= N) return;" << std::endl;
|
||||
for(uint32_t m=0; m < p_.mS; ++m)
|
||||
for(uint32_t m=0; m < mS_; ++m)
|
||||
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
|
||||
for(uint32_t m=0; m < p_.mS; ++m)
|
||||
for(uint32_t m=0; m < mS_; ++m)
|
||||
{
|
||||
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
|
||||
string Ci = to_string((m/vwidth_)*(ls0_*vwidth_) + m%vwidth_);
|
||||
stream << "if(" << Ci << "< M) ";
|
||||
if(has_depth)
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
||||
else
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + ((beta != (" << sdtype << ")0)?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
|
||||
}
|
||||
if((n+1)%p_.vwidth==0){
|
||||
stream << "C += ldc*" << p_.ls1*p_.vwidth - p_.vwidth + 1 << ";" << std::endl;
|
||||
if((n+1)%vwidth_==0){
|
||||
stream << "C += ldc*" << ls1_*vwidth_ - vwidth_ + 1 << ";" << std::endl;
|
||||
}
|
||||
else{
|
||||
stream << "C += ldc;" << std::endl;
|
||||
@@ -594,8 +594,8 @@ namespace templates
|
||||
reduce_name += suffix;
|
||||
|
||||
driver::Kernel gemm(program, gemm_name.c_str());
|
||||
driver::NDRange local(p_.ls0, p_.ls1, 1);
|
||||
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
|
||||
driver::NDRange local(ls0_, ls1_, 1);
|
||||
driver::NDRange global(align(align(M,mS_)/mS_, ls0_), align(align(N,nS_)/nS_, ls1_), depth_);
|
||||
|
||||
uint32_t current_arg = 0;
|
||||
|
||||
@@ -603,7 +603,7 @@ namespace templates
|
||||
gemm.setSizeArg(current_arg++, M);
|
||||
gemm.setSizeArg(current_arg++, N);
|
||||
gemm.setSizeArg(current_arg++, K);
|
||||
if(p_.depth==1)
|
||||
if(depth_==1)
|
||||
{
|
||||
if(backend==driver::OPENCL)
|
||||
gemm.setArg(current_arg++, C.array.handle.cl);
|
||||
@@ -642,15 +642,15 @@ namespace templates
|
||||
gemm.setArg(current_arg++, beta);
|
||||
options.enqueue(program.context(), gemm, global, local);
|
||||
|
||||
if(p_.depth > 1)
|
||||
if(depth_ > 1)
|
||||
{
|
||||
uint32_t current_arg = 0;
|
||||
driver::Kernel reduce(program, reduce_name.c_str());
|
||||
driver::NDRange local(p_.ls0, p_.ls1);
|
||||
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
|
||||
driver::NDRange local(ls0_, ls1_);
|
||||
driver::NDRange global(align(M, ls0_), align(N, ls1_));
|
||||
reduce.setSizeArg(current_arg++, M);
|
||||
reduce.setSizeArg(current_arg++, N);
|
||||
reduce.setSizeArg(current_arg++, p_.depth);
|
||||
reduce.setSizeArg(current_arg++, depth_);
|
||||
reduce.setArg(current_arg++, workspace);
|
||||
reduce.setSizeArg(current_arg++, M);
|
||||
if(backend==driver::OPENCL)
|
||||
@@ -682,7 +682,7 @@ namespace templates
|
||||
,int_t ms, int_t ks, int_t ns
|
||||
,fetch_type Afetch , fetch_type Bfetch
|
||||
,int_t lf0, int_t lf1, char A_trans, char B_trans) :
|
||||
base_impl(vwidth, ls0, ls1), kL_(kL), depth_(D), mS_(ms), kS_(ks), nS_(ns),
|
||||
base_impl(vwidth, ls0, ls1), mL_(ms*ls0), kL_(kL), nL_(ns*ls1), depth_(D), mS_(ms), kS_(ks), nS_(ns),
|
||||
Afetch_(Afetch), Bfetch_(Bfetch), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans)
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN;
|
||||
|
@@ -38,20 +38,20 @@ namespace templates
|
||||
|
||||
// Reports the local-memory requirement of the 1-D reduction template for the
// expression `x`: the first local size (ls0) multiplied by size_of(x.dtype()).
// NOTE(review): this text is a rendered diff without +/- markers; the
// `p_.ls0` return appears to be the removed line and the `ls0_` return its
// replacement -- confirm against the repository.
uint32_t reduce_1d::lmem_usage(expression_tree const & x) const
|
||||
{
|
||||
return p_.ls0*size_of(x.dtype());
|
||||
return ls0_*size_of(x.dtype());
|
||||
}
|
||||
|
||||
// Validity check specific to the 1-D reduction template. Both arguments are
// unnamed and unused: the result depends only on the configured fetch policy.
// Returns TEMPLATE_INVALID_FETCHING_POLICY_TYPE when the policy is
// FETCH_FROM_LOCAL, TEMPLATE_VALID otherwise.
// NOTE(review): rendered diff without +/- markers; the `p_.fetch` condition
// appears to be the removed line and the `fetch_` condition its replacement.
int reduce_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetch==FETCH_FROM_LOCAL)
|
||||
if (fetch_==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
// Size of the temporary workspace requested by the 1-D reduction template.
// The expression argument is unnamed and unused. Returns the group count (ng)
// when it is greater than 1, and 0 otherwise.
// NOTE(review): rendered diff without +/- markers; the `p_.ng` pair of lines
// appears to be the removed version and the `ng_` pair its replacement.
uint32_t reduce_1d::temporary_workspace(expression_tree const &) const
|
||||
{
|
||||
if(p_.ng > 1)
|
||||
return p_.ng;
|
||||
if(ng_ > 1)
|
||||
return ng_;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -99,13 +99,13 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
if (is_indexing(rd->op().type))
|
||||
{
|
||||
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint *)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += 4*p_.ng;
|
||||
offset += 4*ng_;
|
||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += size_of(dtype)*p_.ng;
|
||||
offset += size_of(dtype)*ng_;
|
||||
}
|
||||
else{
|
||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += size_of(dtype)*p_.ng;
|
||||
offset += size_of(dtype)*ng_;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -118,7 +118,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
case driver::CUDA:
|
||||
stream << "#include \"vector.h\"" << std::endl; break;
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << ",1,1)))" << std::endl; break;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << ",1,1)))" << std::endl; break;
|
||||
}
|
||||
stream << "$KERNEL void prod" << suffix << "($SIZE_T N, $GLOBAL char* tmp," << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
@@ -135,18 +135,18 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
{
|
||||
if(is_indexing(rd->op().type))
|
||||
{
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
|
||||
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
||||
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
|
||||
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
|
||||
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";") << std::endl;
|
||||
}
|
||||
}
|
||||
element_wise_loop_1D(stream, p_.fetch, p_.vwidth, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||
element_wise_loop_1D(stream, fetch_, vwidth_, "i", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t vwidth)
|
||||
{
|
||||
std::string dtype = append_width("#scalartype",vwidth);
|
||||
//Fetch vector entry
|
||||
@@ -174,7 +174,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << rd->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||
}
|
||||
//Reduce local memory
|
||||
reduce_1d_local_memory(stream, p_.ls0, reductions, "#name_buf", "#name_buf_value", backend);
|
||||
reduce_1d_local_memory(stream, ls0_, reductions, "#name_buf", "#name_buf_value", backend);
|
||||
//Write to temporary buffers
|
||||
stream << "if (lid==0)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
@@ -205,19 +205,19 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
{
|
||||
if (is_indexing(rd->op().type))
|
||||
{
|
||||
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(p_.ls0) + "];");
|
||||
stream << rd->process("$LOCAL uint32_t #name_buf[" + tools::to_string(ls0_) + "];");
|
||||
stream << rd->process("uint32_t #name_acc = 0;") << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf_value[" + tools::to_string(ls0_) + "];") << std::endl;
|
||||
stream << rd->process("#scalartype #name_acc_value = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
|
||||
}
|
||||
else
|
||||
{
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(p_.ls0) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + tools::to_string(ls0_) + "];") << std::endl;
|
||||
stream << rd->process("#scalartype #name_acc = " + neutral_element(rd->op(), backend, "#scalartype") + ";");
|
||||
}
|
||||
}
|
||||
//Private reduction
|
||||
stream << "for(uint32_t i = lid; i < " << p_.ng << "; i += lsize)" << std::endl;
|
||||
stream << "for(uint32_t i = lid; i < " << ng_ << "; i += lsize)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
for (symbolic::reduce_1d* rd: reductions)
|
||||
@@ -234,7 +234,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << rd->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||
}
|
||||
//Local reduction
|
||||
reduce_1d_local_memory(stream, p_.ls0, reductions, "#name_buf", "#name_buf_value", backend);
|
||||
reduce_1d_local_memory(stream, ls0_, reductions, "#name_buf", "#name_buf_value", backend);
|
||||
//Write
|
||||
stream << "if (lid==0)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
@@ -250,7 +250,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
|
||||
}
|
||||
|
||||
// Constructs a 1-D reduction template from its tuning parameters.
//   vwidth, ls : forwarded to the base_impl constructor
//   ng         : stored in ng_
//   fetch      : stored in fetch_
// NOTE(review): rendered diff without +/- markers; the two initializer lines
// appear to be the removed form `base_impl(vwidth,ls)` followed by its
// replacement `base_impl(vwidth,ls,1)`, which passes an explicit third
// argument of 1 -- confirm against the repository.
reduce_1d::reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch):
|
||||
base_impl(vwidth,ls), ng_(ng), fetch_(fetch)
|
||||
base_impl(vwidth,ls,1), ng_(ng), fetch_(fetch)
|
||||
{}
|
||||
|
||||
std::vector<int_t> reduce_1d::input_sizes(expression_tree const & x) const
|
||||
@@ -275,8 +275,8 @@ void reduce_1d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
||||
driver::Kernel kernels[2] = { driver::Kernel(program,name[0].c_str()), driver::Kernel(program,name[1].c_str()) };
|
||||
|
||||
//NDRange
|
||||
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng), driver::NDRange(p_.ls0) };
|
||||
driver::NDRange local[2] = { driver::NDRange(p_.ls0), driver::NDRange(p_.ls0) };
|
||||
driver::NDRange global[2] = { driver::NDRange(ls0_*ng_), driver::NDRange(ls0_) };
|
||||
driver::NDRange local[2] = { driver::NDRange(ls0_), driver::NDRange(ls0_) };
|
||||
//Arguments
|
||||
for (auto & kernel : kernels)
|
||||
{
|
||||
|
@@ -41,22 +41,22 @@ namespace templates
|
||||
|
||||
// Validity check specific to the 2-D reduction template. Both arguments are
// unnamed and unused: the result depends only on the configured fetch policy.
// Returns TEMPLATE_INVALID_FETCHING_POLICY_TYPE when the policy is
// FETCH_FROM_LOCAL, TEMPLATE_VALID otherwise.
// NOTE(review): rendered diff without +/- markers; the `p_.fetch` condition
// appears to be the removed line and the `fetch_` condition its replacement.
int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
if (p_.fetch==FETCH_FROM_LOCAL)
|
||||
if (fetch_==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
// Reports the local-memory requirement of the 2-D reduction template as
// (ls0 + 1) * ls1. The expression argument is unnamed and unused, so unlike
// reduce_1d::lmem_usage the value is not scaled by the element size.
// NOTE(review): rendered diff without +/- markers; the `p_.` return appears
// to be the removed line and the `ls0_`/`ls1_` return its replacement.
uint32_t reduce_2d::lmem_usage(const expression_tree&) const
|
||||
{
|
||||
return (p_.ls0+1)*p_.ls1;
|
||||
return (ls0_+1)*ls1_;
|
||||
}
|
||||
|
||||
// Size of the temporary workspace requested by the 2-D reduction template.
// M is taken as the first entry of input_sizes(expressions). Returns M * ng0
// when ng0 is greater than 1, and 0 otherwise.
// NOTE(review): rendered diff without +/- markers; the `p_.ng0` pair of lines
// appears to be the removed version and the `ng0_` pair its replacement.
uint32_t reduce_2d::temporary_workspace(expression_tree const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MN = input_sizes(expressions);
|
||||
int_t M = MN[0];
|
||||
if(p_.ng0 > 1)
|
||||
return M*p_.ng0;
|
||||
if(ng0_ > 1)
|
||||
return M*ng0_;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
name[0] += suffix;
|
||||
name[1] += suffix;
|
||||
|
||||
uint32_t ldls = p_.ls0;
|
||||
uint32_t ldls = ls0_;
|
||||
std::string ls0ldstr = to_string(ldls);
|
||||
|
||||
auto unroll_tmp = [&]()
|
||||
@@ -87,13 +87,13 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
if (is_indexing(rd->op().type))
|
||||
{
|
||||
stream << rd->process("$GLOBAL uint* #name_temp = ($GLOBAL uint*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += 4*p_.ng0;
|
||||
offset += 4*ng0_;
|
||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp_value = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.ng0;
|
||||
offset += size_of(dtype)*ng0_;
|
||||
}
|
||||
else{
|
||||
stream << rd->process("$GLOBAL " + sdtype + "* #name_temp = ($GLOBAL " + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.ng0;
|
||||
offset += size_of(dtype)*ng0_;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -107,7 +107,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "#include \"vector.h\"" << std::endl;
|
||||
break;
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
|
||||
break;
|
||||
}
|
||||
stream << "$KERNEL void " << name[0] << "($SIZE_T M, $SIZE_T N, $GLOBAL char* tmp, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
|
||||
@@ -119,13 +119,13 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "$SIZE_T lidy = $LOCAL_IDX_1;" << std::endl;
|
||||
//Loop r
|
||||
std::ostringstream upper;
|
||||
upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1;
|
||||
upper << "(M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_;
|
||||
|
||||
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
|
||||
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth)
|
||||
{
|
||||
//Declare Buffers
|
||||
for (symbolic::reduce_2d* rd : reductions)
|
||||
stream << rd->process("$LOCAL " + append_width("#scalartype", cwidth) + " #name_buf[" + to_string(p_.ls1*ldls) + "];") << std::endl;
|
||||
stream << rd->process("$LOCAL " + append_width("#scalartype", cwidth) + " #name_buf[" + to_string(ls1_*ldls) + "];") << std::endl;
|
||||
|
||||
//Accumulators
|
||||
for (symbolic::reduce_2d* rd : reductions){
|
||||
@@ -136,7 +136,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "if (r < M)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
|
||||
element_wise_loop_1D(stream, fetch_, (reduction_type_==REDUCE_COLUMNS)?vwidth_:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth)
|
||||
{
|
||||
std::string rdtype = append_width("#scalartype", rwidth);
|
||||
std::string cdtype = append_width("#scalartype", cwidth);
|
||||
@@ -167,7 +167,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << rd->process("#name_buf[lidy*" + ls0ldstr + "+ lidx] = #name_acc;") << std::endl;
|
||||
//Reduce local memory
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for($SIZE_T stride = " << p_.ls0/2 << "; stride >0; stride /=2)" << std::endl;
|
||||
stream << "for($SIZE_T stride = " << ls0_/2 << "; stride >0; stride /=2)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||
@@ -189,7 +189,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "if (r < M && lidx == 0)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(p_.ng0==1)
|
||||
if(ng0_==1)
|
||||
for(size_t idx: assignments)
|
||||
for(size_t s = 0 ; s < cwidth ; ++s)
|
||||
stream << symbols.at(idx)->evaluate({{"leaf", "at(r+" + to_string(s) + ")"},
|
||||
@@ -211,17 +211,17 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
/* ------------------------
|
||||
* Kernel 2
|
||||
* -----------------------*/
|
||||
if(p_.ng0>1)
|
||||
if(ng0_>1)
|
||||
{
|
||||
if(backend==driver::OPENCL)
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
||||
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
|
||||
stream << "$KERNEL void " << name[1] << "($SIZE_T M, $SIZE_T N , $GLOBAL char* tmp, " << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
unroll_tmp();
|
||||
for (symbolic::reduce_2d* rd : reductions)
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + to_string(p_.ls1*ldls) + "];") << std::endl;
|
||||
stream << "for($SIZE_T r = $GLOBAL_IDX_1; r < (M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1 << "; r += " << GlobalSize1(backend) << "){" << std::endl;
|
||||
stream << rd->process("$LOCAL #scalartype #name_buf[" + to_string(ls1_*ldls) + "];") << std::endl;
|
||||
stream << "for($SIZE_T r = $GLOBAL_IDX_1; r < (M +" << ls1_ - 1 << ")/" << ls1_ << "*" << ls1_ << "; r += " << GlobalSize1(backend) << "){" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << "$SIZE_T lidx = $LOCAL_IDX_0;" << std::endl;
|
||||
stream << "$SIZE_T lidy = $LOCAL_IDX_1;" << std::endl;
|
||||
@@ -230,7 +230,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
stream << "if (r < M)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << "for($SIZE_T c = lidx; c < " << p_.ng0 << "; c += $LOCAL_SIZE_0){" << std::endl;
|
||||
stream << "for($SIZE_T c = lidx; c < " << ng0_ << "; c += $LOCAL_SIZE_0){" << std::endl;
|
||||
stream.inc_tab();
|
||||
for (symbolic::reduce_2d* rd: reductions)
|
||||
compute_reduce_1d(stream, rd->process("#name_acc"), rd->process("#name_temp[r + M*c]"), rd->op());
|
||||
@@ -241,7 +241,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
|
||||
for (symbolic::reduce_2d* rd : reductions)
|
||||
stream << rd->process("#name_buf[lidy*" + ls0ldstr + "+ lidx] = #name_acc;") << std::endl;
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for($SIZE_T stride = " << p_.ls0/2 << "; stride >0; stride /=2)" << std::endl;
|
||||
stream << "for($SIZE_T stride = " << ls0_/2 << "; stride >0; stride /=2)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||
@@ -300,7 +300,7 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
||||
name[0] += suffix;
|
||||
name[1] += suffix;
|
||||
|
||||
uint32_t nk = (p_.ng0==1)?1:2;
|
||||
uint32_t nk = (ng0_==1)?1:2;
|
||||
|
||||
std::vector<driver::Kernel> kernels;
|
||||
for(uint32_t k = 0 ; k < nk ; ++k)
|
||||
@@ -319,8 +319,8 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr
|
||||
}
|
||||
|
||||
//NDRange
|
||||
driver::NDRange global[2] = { driver::NDRange(p_.ls0*p_.ng0, p_.ls1*p_.ng1), driver::NDRange(p_.ls0, p_.ls1*p_.ng1) };
|
||||
driver::NDRange local[2] = { driver::NDRange(p_.ls0, p_.ls1), driver::NDRange(p_.ls0, p_.ls1) };
|
||||
driver::NDRange global[2] = { driver::NDRange(ls0_*ng0_, ls1_*ng1_), driver::NDRange(ls0_, ls1_*ng1_) };
|
||||
driver::NDRange local[2] = { driver::NDRange(ls0_, ls1_), driver::NDRange(ls0_, ls1_) };
|
||||
for(uint32_t i = 0 ; i < nk ; ++i)
|
||||
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
||||
}
|
||||
|
@@ -23,6 +23,8 @@
|
||||
#include "common.hpp"
|
||||
#include "core.h"
|
||||
|
||||
namespace tpt = sc::templates;
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
@@ -103,7 +105,8 @@ namespace detail
|
||||
{
|
||||
// Builds an rt::profiles::value_type from a Python-side template object `tp`,
// a Python dtype object and a command queue, and returns it in a shared_ptr.
// The template type and dtype are converted with tools::extract_template_type
// and tools::extract_dtype.
// NOTE(review): rendered diff without +/- markers. The first `return` appears
// to be the removed version, which passed the extracted template as a
// `base const &`. The two lines after it appear to be the replacement: they
// extract a raw tpt::base* and pass raw->getptr() instead.
// NOTE(review): the commit header shows getptr() returning
// shared_from_this(), which is only well-defined when the object is already
// owned by a std::shared_ptr -- confirm the Boost.Python class registration
// guarantees that for every template exposed to Python.
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
|
||||
{
|
||||
return std::shared_ptr<rt::profiles::value_type>(new rt::profiles::value_type(tools::extract_template_type(tp), tools::extract_dtype(dtype), (sc::templates::base const &)bp::extract<sc::templates::base>(tp), queue));
|
||||
tpt::base* raw = bp::extract<tpt::base*>(tp);
|
||||
return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
|
||||
}
|
||||
|
||||
std::shared_ptr<sc::array>
|
||||
|
@@ -58,8 +58,8 @@ void export_templates()
|
||||
|
||||
//Base
|
||||
{
|
||||
#define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
|
||||
bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
|
||||
#define __PROP(name) .def_readonly(#name, &tpt::base::name)
|
||||
bp::class_<tpt::base, std::shared_ptr<tpt::base>, boost::noncopyable>("base", bp::no_init)
|
||||
.def("lmem_usage", &tpt::base::lmem_usage)
|
||||
.def("registers_usage", &tpt::base::registers_usage)
|
||||
.def("is_invalid", &tpt::base::is_invalid)
|
||||
@@ -68,26 +68,23 @@ void export_templates()
|
||||
#undef __PROP
|
||||
}
|
||||
|
||||
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init)\
|
||||
.add_property("ls0", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls0)\
|
||||
.add_property("ls1", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls1);
|
||||
bp::class_<tpt::base_impl, bp::bases<tpt::base>, boost::noncopyable>("base_impl", bp::no_init)
|
||||
.add_property("ls0", &tpt::base_impl::ls0)
|
||||
.add_property("ls1", &tpt::base_impl::ls1);
|
||||
|
||||
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
|
||||
#define WRAP_BASE(name) bp::class_<tpt::name, bp::bases<tpt::base_impl>, boost::noncopyable>(#name, bp::no_init);
|
||||
|
||||
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, std::shared_ptr<tpt::name>, bp::bases<basename>>(#name, bp::init<__VA_ARGS__>())\
|
||||
;
|
||||
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
|
||||
|
||||
//Vector AXPY
|
||||
WRAP_SINGLE_TEMPLATE(elementwise_1d, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_SINGLE_TEMPLATE(elementwise_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_SINGLE_TEMPLATE(reduce_1d, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(elementwise_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(elementwise_2d, tpt::base_impl, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_BASE(reduce_2d)
|
||||
WRAP_TEMPLATE(reduce_2d_rows, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_cols, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_rows, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_cols, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_BASE(gemm)
|
||||
WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
|
||||
|
||||
WRAP_TEMPLATE(gemm_nn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_nt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
}
|
||||
|
Reference in New Issue
Block a user