diff --git a/lib/jit/generation/matrix_product.cpp b/lib/jit/generation/matrix_product.cpp index 311de190f..cbb5bc852 100644 --- a/lib/jit/generation/matrix_product.cpp +++ b/lib/jit/generation/matrix_product.cpp @@ -382,14 +382,15 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl; stream << "$LOCAL_BARRIER;" << std::endl; - + std::string bound = last_iteration?"K":tools::to_string(p_.kL); + size_t ks = last_iteration?1:p_.kS; stream << "//Inner loop" << std::endl; - stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl; + stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl; stream.inc_tab(); stream << "//Fetch A to registers" << std::endl; stream << "#pragma unroll" << std::endl; - stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl; + stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl; stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl; stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl; stream << "{" << std::endl; @@ -409,8 +410,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth stream << "}" << std::endl; stream << "//Fetch B to registers" << std::endl; - stream << "#pragma unroll " << p_.kS << std::endl; - stream << "for(unsigned int kk = 0; kk < " << p_.kS << "; kk++)" << std::endl; + stream << "#pragma unroll " << ks << std::endl; + stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl; stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl; stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl; stream << "{" << std::endl; @@ -429,22 +430,25 @@ matrix_product_parameters::matrix_product_parameters(unsigned int vwidth stream << "}" << std::endl; stream << "//FMA computations" << std::endl; - for(unsigned int kk=0 ; kk < p_.kS; ++kk) + stream << "#pragma unroll" << std::endl; + stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl; + stream.inc_tab(); for(unsigned int nn=0; nn < p_.nS; ++nn) for(unsigned int mm=0; mm < p_.mS; ++mm){ string res_str, lhs_str, rhs_str; res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]"; if (p_.vwidth==1) - lhs_str = "rA[" + to_string(kk) + "][" + to_string(mm) + "]"; + lhs_str = "rA[kk][" + to_string(mm) + "]"; else - lhs_str = access_vector_type("rA[" + to_string(kk) + "][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth); + lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth); if (p_.vwidth==1) - rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]"; + rhs_str = "rB[kk]["+to_string(nn)+"]"; else - rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth); + rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth); stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl; } - + stream.dec_tab(); + stream << "}" << std::endl; stream.dec_tab(); stream << "}" << std::endl; stream << "K -= " << p_.kL << ";" << std::endl;