Code Generation/GEMM: Reverted to faster inner loop
This commit is contained in:
@@ -382,14 +382,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
|
||||
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
|
||||
|
||||
stream << "$LOCAL_BARRIER;" << std::endl;
|
||||
|
||||
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
|
||||
size_t ks = last_iteration?1:p_.kS;
|
||||
stream << "//Inner loop" << std::endl;
|
||||
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
|
||||
stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
stream << "//Fetch A to registers" << std::endl;
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
||||
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
@@ -409,8 +410,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
|
||||
stream << "}" << std::endl;
|
||||
|
||||
stream << "//Fetch B to registers" << std::endl;
|
||||
stream << "#pragma unroll " << p_.kS << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << ks << std::endl;
|
||||
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
||||
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
||||
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
@@ -429,22 +430,25 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
|
||||
stream << "}" << std::endl;
|
||||
|
||||
stream << "//FMA computations" << std::endl;
|
||||
for(unsigned int kk=0 ; kk < p_.kS; ++kk)
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
|
||||
stream.inc_tab();
|
||||
for(unsigned int nn=0; nn < p_.nS; ++nn)
|
||||
for(unsigned int mm=0; mm < p_.mS; ++mm){
|
||||
string res_str, lhs_str, rhs_str;
|
||||
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
||||
if (p_.vwidth==1)
|
||||
lhs_str = "rA[" + to_string(kk) + "][" + to_string(mm) + "]";
|
||||
lhs_str = "rA[kk][" + to_string(mm) + "]";
|
||||
else
|
||||
lhs_str = access_vector_type("rA[" + to_string(kk) + "][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
|
||||
lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
|
||||
if (p_.vwidth==1)
|
||||
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
|
||||
rhs_str = "rB[kk]["+to_string(nn)+"]";
|
||||
else
|
||||
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
|
||||
rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
|
||||
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||
}
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
stream << "K -= " << p_.kL << ";" << std::endl;
|
||||
|
Reference in New Issue
Block a user