Code quality: shortened parameter names in JIT code generator
This commit is contained in:
@@ -37,14 +37,14 @@ namespace isaac
|
||||
namespace templates
|
||||
{
|
||||
|
||||
// Constructor for the matrix-product tuning parameters.
// NOTE(review): this span is a rendered diff, so the pre-rename lines
// (simd_width / local_size_* / local_fetch_* / *_fetching_policy) and the
// post-rename lines (vwidth / ls* / lf* / Afetch / Bfetch) both appear;
// presumably only the shorter-named lines form the post-commit definition.
//
// What the initializer list does:
//  - passes the vector width, the two local sizes and a literal 1 to
//    base::parameters_type;
//  - copies KL, D, ms, ks, ns into kL, depth, mS, kS, nS;
//  - stores the A/B fetch policies and the two local-fetch sizes in members
//    that have the same names as the parameters;
//  - derives mL = ms * (local size 0) and nL = ns * (local size 1).
// The body is empty.
matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
, unsigned int local_size_0, unsigned int KL, unsigned int local_size_1, unsigned int D
|
||||
matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
|
||||
, unsigned int ls0, unsigned int KL, unsigned int ls1, unsigned int D
|
||||
, unsigned int ms, unsigned int ks, unsigned int ns
|
||||
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
||||
, unsigned int local_fetch_0, unsigned int local_fetch_1): base::parameters_type(simd_width, local_size_0, local_size_1, 1),
|
||||
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), A_fetching_policy(A_fetching_policy), B_fetching_policy(B_fetching_policy),
|
||||
local_fetch_0(local_fetch_0), local_fetch_1(local_fetch_1),
|
||||
mL(ms*local_size_0), nL(ns*local_size_1)
|
||||
, fetch_type Afetch, fetch_type Bfetch
|
||||
, unsigned int lf0, unsigned int lf1): base::parameters_type(vwidth, ls0, ls1, 1),
|
||||
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
|
||||
lf0(lf0), lf1(lf1),
|
||||
mL(ms*ls0), nL(ns*ls1)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -74,13 +74,10 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
|
||||
int matrix_product::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
||||
{
|
||||
// if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
|
||||
// return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
|
||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||
if(p_.Afetch!=FETCH_FROM_LOCAL || p_.Bfetch!=FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
|
||||
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
|
||||
if ((p_.mS % p_.vwidth) > 0 || (p_.nS % p_.vwidth) > 0)
|
||||
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
|
||||
|
||||
if(p_.mL > 256 || p_.nL > 256)
|
||||
@@ -89,32 +86,32 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
if ( p_.kS % p_.kL == 0)
|
||||
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
|
||||
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){
|
||||
if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
|
||||
if (p_.Afetch==FETCH_FROM_LOCAL || p_.Bfetch==FETCH_FROM_LOCAL){
|
||||
if ((p_.lf0*p_.lf1) !=(p_.ls0*p_.ls1))
|
||||
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
|
||||
}
|
||||
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
|
||||
if (p_.Afetch==FETCH_FROM_LOCAL)
|
||||
{
|
||||
unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
|
||||
|
||||
if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
|
||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
|
||||
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
||||
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
|
||||
|
||||
}
|
||||
if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
|
||||
if (p_.Bfetch==FETCH_FROM_LOCAL)
|
||||
{
|
||||
unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
|
||||
|
||||
if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
|
||||
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
||||
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
|
||||
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
||||
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
||||
|
||||
}
|
||||
@@ -129,9 +126,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
|
||||
driver::backend_type backend = device.backend();
|
||||
bool has_depth = p_.depth > 1;
|
||||
#define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, true)
|
||||
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, false)
|
||||
#define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, "1", backend)
|
||||
#define VLOAD(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, true)
|
||||
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, false)
|
||||
#define VSTORE(value, offset, ptr) vstore(p_.vwidth, sdtype, value, offset, ptr, "1", backend)
|
||||
|
||||
symbolic::preset::matrix_product::args args;
|
||||
infos(tree, args);
|
||||
@@ -145,7 +142,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
kernel_generation_stream stream(backend);
|
||||
numeric_type dtype = tree.dtype();
|
||||
std::string sdtype = to_string(dtype);
|
||||
std::string vdtype = append_width(sdtype, p_.simd_width);
|
||||
std::string vdtype = append_width(sdtype, p_.vwidth);
|
||||
|
||||
//////////////////
|
||||
/// DECLARATIONS
|
||||
@@ -159,7 +156,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
switch(backend)
|
||||
{
|
||||
case driver::OPENCL:
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -178,8 +175,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
///Declare
|
||||
stream << "//blocks" << std::endl;
|
||||
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{0}};" << std::endl;
|
||||
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.simd_width << "];" << std::endl;
|
||||
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.simd_width << "];" << std::endl;
|
||||
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.vwidth << "];" << std::endl;
|
||||
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.vwidth << "];" << std::endl;
|
||||
stream << std::endl;
|
||||
|
||||
stream << "//pointers" << std::endl;
|
||||
@@ -187,8 +184,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL;
|
||||
stream << "$LOCAL " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
|
||||
stream << "$LOCAL " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
|
||||
unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
||||
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
||||
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
||||
stream << std::endl;
|
||||
@@ -218,15 +215,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "K = min(K - div*gidz, ($SIZE_T)div);" << std::endl;
|
||||
}
|
||||
|
||||
stream << "idt = " << p_.local_size_0 << "*ids.w + ids.z;" << std::endl;
|
||||
stream << "idT.y = idt/" << p_.local_fetch_0 << ";" << std::endl;
|
||||
stream << "idT.x = idt - " << p_.local_fetch_0 << "*idT.y;" << std::endl;
|
||||
stream << "idt = " << p_.ls0 << "*ids.w + ids.z;" << std::endl;
|
||||
stream << "idT.y = idt/" << p_.lf0 << ";" << std::endl;
|
||||
stream << "idT.x = idt - " << p_.lf0 << "*idT.y;" << std::endl;
|
||||
stream << std::endl;
|
||||
|
||||
stream << "//Adjust pointers and bounds per work-item" << std::endl;
|
||||
stream << "ids.x *= " << p_.mL << ";" << std::endl;
|
||||
stream << "ids.y *= " << p_.nL << ";" << std::endl;
|
||||
stream << "idT.x *= " << p_.simd_width << ";" << std::endl;
|
||||
stream << "idT.x *= " << p_.vwidth << ";" << std::endl;
|
||||
|
||||
stream << "M -= ids.x;" << std::endl;
|
||||
if(A_trans_=='N')
|
||||
@@ -289,15 +286,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
|
||||
for(unsigned int i = 0 ; i < npA ; i++ )
|
||||
if (A_trans_=='N')
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_0*p_.simd_width) + " < M", "(int)((idT.x + " + to_string(i*p_.local_fetch_0*p_.simd_width) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
else
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_1) + " < M", "(int)((idT.y + " + to_string(i*p_.local_fetch_1) + ")*lda)", "0") << ";" << std::endl;
|
||||
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
|
||||
|
||||
for(unsigned int i = 0 ; i < npB ; i++ )
|
||||
if (B_trans_=='T')
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_0*p_.simd_width) + " < N", "(int)((idT.x + " + to_string(i*p_.local_fetch_0*p_.simd_width) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
||||
else
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_1) + " < N", "(int)((idT.y + " + to_string(i*p_.local_fetch_1) + ")*ldb)", "0") << ";" << std::endl;
|
||||
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < N", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*ldb)", "0") << ";" << std::endl;
|
||||
|
||||
stream << std::endl;
|
||||
stream << "//Outer loop" << std::endl;
|
||||
@@ -315,13 +312,13 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "//Fetch A to local memory" << std::endl;
|
||||
if (A_trans_=='N')
|
||||
{
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||
for(unsigned int m = 0; m < p_.mL; m += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
||||
for(unsigned int m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
|
||||
{
|
||||
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
|
||||
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
|
||||
else
|
||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
|
||||
@@ -329,13 +326,13 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
}
|
||||
else
|
||||
{
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int m = 0; m < p_.mL; m += p_.local_fetch_1)
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||
for(unsigned int m = 0; m < p_.mL; m += p_.lf1)
|
||||
{
|
||||
std::string mm = to_string(m/p_.local_fetch_1);
|
||||
std::string mm = to_string(m/p_.lf1);
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
|
||||
|
||||
else
|
||||
@@ -346,13 +343,13 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "//Fetch B to local memory" << std::endl;
|
||||
if (B_trans_=='T')
|
||||
{
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||
for(unsigned int n = 0; n < p_.nL; n += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
||||
for(unsigned int n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
|
||||
{
|
||||
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
|
||||
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
|
||||
else
|
||||
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
||||
@@ -360,13 +357,13 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
}
|
||||
else
|
||||
{
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int n = 0; n < p_.nL; n += p_.local_fetch_1)
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
||||
for(unsigned int n = 0; n < p_.nL; n += p_.lf1)
|
||||
{
|
||||
std::string nn = to_string(n/p_.local_fetch_1);
|
||||
std::string nn = to_string(n/p_.lf1);
|
||||
std::string kk = to_string(k);
|
||||
if(last_iteration)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
|
||||
|
||||
else
|
||||
@@ -375,14 +372,14 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
}
|
||||
|
||||
if(A_trans_=='N')
|
||||
stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
||||
stream << "ldsA = lA + ids.z*" << p_.vwidth << ";" << std::endl;
|
||||
else
|
||||
stream << "ldsA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
|
||||
stream << "ldsA = lA + ids.z*" << llda*p_.vwidth << ";" << std::endl;
|
||||
|
||||
if(B_trans_=='T')
|
||||
stream << "ldsB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
|
||||
stream << "ldsB = lB + ids.w*" << p_.vwidth << ";" << std::endl;
|
||||
else
|
||||
stream << "ldsB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
|
||||
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
|
||||
|
||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||
|
||||
@@ -393,19 +390,19 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "//Fetch A to registers" << std::endl;
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.mS/p_.simd_width << std::endl;
|
||||
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.simd_width << "; mm++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
||||
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(A_trans_=='N')
|
||||
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(p_.local_size_0*p_.simd_width) + "+ kk*" + to_string(llda)) << ";" << std::endl;
|
||||
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(p_.ls0*p_.vwidth) + "+ kk*" + to_string(llda)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.simd_width==1)
|
||||
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.local_size_0*llda << "+ kk" << "];" << std::endl;
|
||||
if(p_.vwidth==1)
|
||||
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.simd_width*p_.local_size_0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
||||
}
|
||||
|
||||
stream.dec_tab();
|
||||
@@ -414,19 +411,19 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "//Fetch B to registers" << std::endl;
|
||||
stream << "#pragma unroll " << p_.kS << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.nS/p_.simd_width << std::endl;
|
||||
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.simd_width << "; nn++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
||||
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(B_trans_=='T')
|
||||
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(p_.local_size_1*p_.simd_width) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
|
||||
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(p_.ls1*p_.vwidth) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.simd_width==1)
|
||||
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.local_size_1*lldb << "+ kk" << "];" << std::endl;
|
||||
if(p_.vwidth==1)
|
||||
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.simd_width*p_.local_size_1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
||||
}
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
@@ -437,14 +434,14 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
for(unsigned int mm=0; mm < p_.mS; ++mm){
|
||||
string res_str, lhs_str, rhs_str;
|
||||
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
||||
if (p_.simd_width==1)
|
||||
if (p_.vwidth==1)
|
||||
lhs_str = "rA[" + to_string(kk) + "][" + to_string(mm) + "]";
|
||||
else
|
||||
lhs_str = access_vector_type("rA[" + to_string(kk) + "][" + to_string(mm/p_.simd_width) + "]", mm%p_.simd_width);
|
||||
if (p_.simd_width==1)
|
||||
lhs_str = access_vector_type("rA[" + to_string(kk) + "][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
|
||||
if (p_.vwidth==1)
|
||||
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
|
||||
else
|
||||
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width);
|
||||
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
|
||||
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||
}
|
||||
|
||||
@@ -476,15 +473,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
{
|
||||
stream << "int Ky = K - idT.y;" << std::endl;
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||
}
|
||||
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
{
|
||||
stream << "int Kx = K - idT.x;" << std::endl;
|
||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
|
||||
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||
}
|
||||
fetch_to_lds(true);
|
||||
@@ -503,35 +500,35 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
stream << "N += ids.y;" << std::endl;
|
||||
|
||||
stream << "C += ids.x" << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.z*" << p_.simd_width << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.z*" << p_.vwidth << CSTRIDE1 << ";" << std::endl;
|
||||
stream << "C += ids.y*ldc;" << std::endl;
|
||||
stream << "C += ids.w*" << p_.simd_width << "*ldc;" << std::endl;
|
||||
stream << "C += ids.w*" << p_.vwidth << "*ldc;" << std::endl;
|
||||
if(has_depth)
|
||||
stream << "C += gidz*ldc*N;" << std::endl;
|
||||
|
||||
stream << "M -= ids.x;" << std::endl;
|
||||
stream << "M -= ids.z*" << p_.simd_width << ";" << std::endl;
|
||||
stream << "M -= ids.z*" << p_.vwidth << ";" << std::endl;
|
||||
|
||||
stream << "N -= ids.y;" << std::endl;
|
||||
stream << "N -= ids.w*" << p_.simd_width << ";" << std::endl;
|
||||
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
|
||||
|
||||
for(unsigned int n=0; n < p_.nS; ++n)
|
||||
{
|
||||
string Cj = to_string((n/p_.simd_width)*(p_.local_size_1*p_.simd_width) + n%p_.simd_width);
|
||||
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
|
||||
stream << "if(" << Cj << " >= N) return;" << std::endl;
|
||||
for(unsigned int m=0; m < p_.mS; ++m)
|
||||
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
|
||||
for(unsigned int m=0; m < p_.mS; ++m)
|
||||
{
|
||||
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
|
||||
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
|
||||
stream << "if(" << Ci << "< M) ";
|
||||
if(has_depth)
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
||||
else
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + ((beta != (" << sdtype << ")0)?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
|
||||
}
|
||||
if((n+1)%p_.simd_width==0){
|
||||
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
|
||||
if((n+1)%p_.vwidth==0){
|
||||
stream << "C += ldc*" << p_.ls1*p_.vwidth - p_.vwidth + 1 << ";" << std::endl;
|
||||
}
|
||||
else{
|
||||
stream << "C += ldc;" << std::endl;
|
||||
@@ -599,8 +596,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
reduce_name += suffix;
|
||||
|
||||
driver::Kernel matrix_product(program, matrix_product_name.c_str());
|
||||
driver::NDRange local(p_.local_size_0, p_.local_size_1, 1);
|
||||
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth);
|
||||
driver::NDRange local(p_.ls0, p_.ls1, 1);
|
||||
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
|
||||
|
||||
unsigned int current_arg = 0;
|
||||
|
||||
@@ -651,8 +648,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
{
|
||||
unsigned int current_arg = 0;
|
||||
driver::Kernel reduce(program, reduce_name.c_str());
|
||||
driver::NDRange local(p_.local_size_0, p_.local_size_1);
|
||||
driver::NDRange global(align(M, p_.local_size_0), align(N, p_.local_size_1));
|
||||
driver::NDRange local(p_.ls0, p_.ls1);
|
||||
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
|
||||
reduce.setSizeArg(current_arg++, M);
|
||||
reduce.setSizeArg(current_arg++, N);
|
||||
reduce.setSizeArg(current_arg++, p_.depth);
|
||||
@@ -717,7 +714,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
matrix_product_nn::matrix_product_nn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'N')
|
||||
{
|
||||
@@ -727,7 +724,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
// 'T','N' variant: bundles every tuning argument into a
// matrix_product_parameters object and hands it to the matrix_product base
// together with the character flags 'T' and 'N' (presumably the A and B
// transposition flags, in that order; confirm against matrix_product's
// constructor, which is not visible here). The body is empty.
// NOTE(review): rendered diff; the fetching_policy_type line is the
// pre-rename form of the fetch_type line that follows it.
matrix_product_tn::matrix_product_tn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'N')
|
||||
{ }
|
||||
@@ -736,7 +733,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
// 'N','T' variant: bundles every tuning argument into a
// matrix_product_parameters object and hands it to the matrix_product base
// together with the character flags 'N' and 'T' (presumably the A and B
// transposition flags, in that order; confirm against matrix_product's
// constructor, which is not visible here). The body is empty.
// NOTE(review): rendered diff; the fetching_policy_type line is the
// pre-rename form of the fetch_type line that follows it.
matrix_product_nt::matrix_product_nt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'T')
|
||||
{ }
|
||||
@@ -745,7 +742,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
// 'T','T' variant: bundles every tuning argument into a
// matrix_product_parameters object and hands it to the matrix_product base
// together with the character flags 'T' and 'T' (presumably the A and B
// transposition flags, in that order; confirm against matrix_product's
// constructor, which is not visible here). The body is empty.
// NOTE(review): rendered diff; the fetching_policy_type line is the
// pre-rename form of the fetch_type line that follows it.
matrix_product_tt::matrix_product_tt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, fetch_type Afetch , fetch_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'T')
|
||||
{ }
|
||||
|
Reference in New Issue
Block a user