Code Generation/GEMM: Reverted to faster inner loop

This commit is contained in:
Philippe Tillet
2016-09-27 23:44:22 -04:00
parent bcf760967a
commit 29b3a576df

View File

@@ -382,14 +382,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
stream << "$LOCAL_BARRIER;" << std::endl;
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
size_t ks = last_iteration?1:p_.kS;
stream << "//Inner loop" << std::endl;
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
stream.inc_tab();
stream << "//Fetch A to registers" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
stream << "{" << std::endl;
@@ -409,8 +410,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
stream << "}" << std::endl;
stream << "//Fetch B to registers" << std::endl;
stream << "#pragma unroll " << p_.kS << std::endl;
stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl;
stream << "#pragma unroll " << ks << std::endl;
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
stream << "{" << std::endl;
@@ -429,22 +430,25 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth
stream << "}" << std::endl;
stream << "//FMA computations" << std::endl;
for(unsigned int kk=0 ; kk < p_.kS; ++kk)
stream << "#pragma unroll" << std::endl;
stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
stream.inc_tab();
for(unsigned int nn=0; nn < p_.nS; ++nn)
for(unsigned int mm=0; mm < p_.mS; ++mm){
string res_str, lhs_str, rhs_str;
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
if (p_.vwidth==1)
lhs_str = "rA[" + to_string(kk) + "][" + to_string(mm) + "]";
lhs_str = "rA[kk][" + to_string(mm) + "]";
else
lhs_str = access_vector_type("rA[" + to_string(kk) + "][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
if (p_.vwidth==1)
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
rhs_str = "rB[kk]["+to_string(nn)+"]";
else
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
}
stream.dec_tab();
stream << "}" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
stream << "K -= " << p_.kL << ";" << std::endl;