GEMM: Better update of Kx, Ky
This commit is contained in:
@@ -121,6 +121,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
std::string sdtype = numeric_type_to_string(dtype);
|
||||
std::string vdtype = append_width(sdtype, p_.simd_width);
|
||||
std::string _size_t = size_type(device);
|
||||
std::string vint = append_width("int", p_.simd_width);
|
||||
|
||||
//////////////////
|
||||
/// DECLARATIONS
|
||||
@@ -259,17 +260,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||
|
||||
stream << "//Outer loop" << std::endl;
|
||||
stream << "while(K > 0){" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
stream << "Ky = K - idT.y;" << std::endl;
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
stream << "Kx = K - idT.x;" << std::endl;
|
||||
|
||||
std::string vint = append_width("int", p_.simd_width);
|
||||
stream << "//Outer loop" << std::endl;
|
||||
stream << "while(K > 0){" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||
|
||||
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
{
|
||||
@@ -419,6 +419,13 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
|
||||
stream << "K -= " << p_.kL << ";" << std::endl;
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
stream << "Ky -= " << p_.kL << ";" << std::endl;
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
stream << "Kx -= " << p_.kL << ";" << std::endl;
|
||||
|
||||
//Increment A pointers to global memory
|
||||
if (A_trans_=='N')
|
||||
for(unsigned int i = 0 ; i < npA ; ++i)
|
||||
@@ -435,7 +442,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
for(unsigned int i = 0 ; i < npB ; ++i)
|
||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||
|
||||
stream << "K -= " << p_.kL << ";" << std::endl;
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
|
Reference in New Issue
Block a user