GEMM: Better update of Kx, Ky
This commit is contained in:
@@ -121,6 +121,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string sdtype = numeric_type_to_string(dtype);
|
std::string sdtype = numeric_type_to_string(dtype);
|
||||||
std::string vdtype = append_width(sdtype, p_.simd_width);
|
std::string vdtype = append_width(sdtype, p_.simd_width);
|
||||||
std::string _size_t = size_type(device);
|
std::string _size_t = size_type(device);
|
||||||
|
std::string vint = append_width("int", p_.simd_width);
|
||||||
|
|
||||||
//////////////////
|
//////////////////
|
||||||
/// DECLARATIONS
|
/// DECLARATIONS
|
||||||
@@ -259,17 +260,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||||
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||||
|
|
||||||
stream << "//Outer loop" << std::endl;
|
|
||||||
stream << "while(K > 0){" << std::endl;
|
|
||||||
stream.inc_tab();
|
|
||||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
|
||||||
|
|
||||||
if(A_trans_=='N' || B_trans_=='T')
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
stream << "Ky = K - idT.y;" << std::endl;
|
stream << "Ky = K - idT.y;" << std::endl;
|
||||||
if(A_trans_=='T' || B_trans_=='N')
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
stream << "Kx = K - idT.x;" << std::endl;
|
stream << "Kx = K - idT.x;" << std::endl;
|
||||||
|
|
||||||
std::string vint = append_width("int", p_.simd_width);
|
stream << "//Outer loop" << std::endl;
|
||||||
|
stream << "while(K > 0){" << std::endl;
|
||||||
|
stream.inc_tab();
|
||||||
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
if(A_trans_=='N' || B_trans_=='T')
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
{
|
{
|
||||||
@@ -419,6 +419,13 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
|
stream << "K -= " << p_.kL << ";" << std::endl;
|
||||||
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
|
stream << "Ky -= " << p_.kL << ";" << std::endl;
|
||||||
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
|
stream << "Kx -= " << p_.kL << ";" << std::endl;
|
||||||
|
|
||||||
//Increment A pointers to global memory
|
//Increment A pointers to global memory
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
for(unsigned int i = 0 ; i < npA ; ++i)
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
||||||
@@ -435,7 +442,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
for(unsigned int i = 0 ; i < npB ; ++i)
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
||||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||||
|
|
||||||
stream << "K -= " << p_.kL << ";" << std::endl;
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user