GEMM: More optimizations

This commit is contained in:
Philippe Tillet
2015-07-18 17:23:53 -04:00
parent 6ccf32904a
commit f4615446c5
2 changed files with 39 additions and 31 deletions

View File

@@ -163,12 +163,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA, * readB, * storeA, * storeB;" << std::endl;
stream << "long4 ids;" << std::endl;
stream << "int2 idT;" << std::endl;
stream << _size_t << " idt;" << std::endl;
if(has_depth)
stream << _size_t << " gidz, div, offz;" << std::endl;
stream << "int Ky, Kx;" << std::endl;
stream << "A += offa;" << std::endl;
stream << "B += offb;" << std::endl;
stream << "C += offc;" << std::endl;
@@ -228,15 +231,19 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "Ky = K - idT.y;" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "Kx = K - idT.x;" << std::endl;
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
@@ -246,8 +253,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
else
@@ -258,8 +265,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -272,8 +279,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
else
@@ -284,21 +291,21 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N')
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
stream << "readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
stream << "readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
if(B_trans_=='T')
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
stream << "readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
stream << "readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
stream << "//Inner loop" << std::endl;