GEMM: Better update of Kx, Ky

This commit is contained in:
Philippe Tillet
2015-07-21 14:35:30 -04:00
parent 33bd3a77fc
commit 79f833ba65

View File

@@ -121,6 +121,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string sdtype = numeric_type_to_string(dtype); std::string sdtype = numeric_type_to_string(dtype);
std::string vdtype = append_width(sdtype, p_.simd_width); std::string vdtype = append_width(sdtype, p_.simd_width);
std::string _size_t = size_type(device); std::string _size_t = size_type(device);
std::string vint = append_width("int", p_.simd_width);
////////////////// //////////////////
/// DECLARATIONS /// DECLARATIONS
@@ -259,17 +260,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl; stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl; stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T') if(A_trans_=='N' || B_trans_=='T')
stream << "Ky = K - idT.y;" << std::endl; stream << "Ky = K - idT.y;" << std::endl;
if(A_trans_=='T' || B_trans_=='N') if(A_trans_=='T' || B_trans_=='N')
stream << "Kx = K - idT.x;" << std::endl; stream << "Kx = K - idT.x;" << std::endl;
std::string vint = append_width("int", p_.simd_width); stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T') if(A_trans_=='N' || B_trans_=='T')
{ {
@@ -419,6 +419,13 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
stream << "K -= " << p_.kL << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "Ky -= " << p_.kL << ";" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "Kx -= " << p_.kL << ";" << std::endl;
//Increment A pointers to global memory //Increment A pointers to global memory
if (A_trans_=='N') if (A_trans_=='N')
for(unsigned int i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
@@ -435,7 +442,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(unsigned int i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl; stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
stream << "K -= " << p_.kL << ";" << std::endl;
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;