GEMM: More optimizations

This commit is contained in:
Philippe Tillet
2015-07-18 17:23:53 -04:00
parent 6ccf32904a
commit f4615446c5
2 changed files with 39 additions and 31 deletions

View File

@@ -318,27 +318,28 @@ void bench(ad::numeric_type dtype, std::string operation)
if(operation.substr(0,4)=="gemm")
{
std::vector<std::tuple<char, char, int_t, int_t, int_t> > MNKs;
MNKs.push_back(std::make_tuple('N','T',1536,1536,1536));
//AlexNet (Forward)
MNKs.push_back(std::make_tuple('N','N',3025,96,363));
MNKs.push_back(std::make_tuple('N','N',729,128,1200));
// MNKs.push_back(std::make_tuple('N','N',169,384,2304));
// MNKs.push_back(std::make_tuple('N','N',169,192,1728));
// MNKs.push_back(std::make_tuple('N','N',169,128,1728));
// //AlexNet (Backward)
// MNKs.push_back(std::make_tuple('T','N',1728,128,169));
// MNKs.push_back(std::make_tuple('T','N',1728,192,169));
// MNKs.push_back(std::make_tuple('T','N',2304,384,169));
// MNKs.push_back(std::make_tuple('T','N',1200,128,729));
// MNKs.push_back(std::make_tuple('T','N',363,96,3025));
MNKs.push_back(std::make_tuple('N','N',169,384,2304));
MNKs.push_back(std::make_tuple('N','N',169,192,1728));
MNKs.push_back(std::make_tuple('N','N',169,128,1728));
//AlexNet (Backward)
MNKs.push_back(std::make_tuple('T','N',1728,128,169));
MNKs.push_back(std::make_tuple('T','N',1728,192,169));
MNKs.push_back(std::make_tuple('T','N',2304,384,169));
MNKs.push_back(std::make_tuple('T','N',1200,128,729));
MNKs.push_back(std::make_tuple('T','N',363,96,3025));
// MNKs.push_back(std::make_tuple('N','T',169,1728,128));
// MNKs.push_back(std::make_tuple('N','T',169,1728,192));
// MNKs.push_back(std::make_tuple('N','T',169,2304,384));
// MNKs.push_back(std::make_tuple('N','T',729,1200,128));
MNKs.push_back(std::make_tuple('N','T',169,1728,128));
MNKs.push_back(std::make_tuple('N','T',169,1728,192));
MNKs.push_back(std::make_tuple('N','T',169,2304,384));
MNKs.push_back(std::make_tuple('N','T',729,1200,128));
// //Covariance (e.g., ICA)
// MNKs.push_back(std::make_tuple('N','N',64,64,32000));
// MNKs.push_back(std::make_tuple('N','N',1024,1024,32000));
//Covariance (e.g., ICA)
MNKs.push_back(std::make_tuple('N','N',64,64,32000));
MNKs.push_back(std::make_tuple('N','N',1024,1024,32000));
/*---------*/
/*--BLAS3--*/

View File

@@ -163,12 +163,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA, * readB, * storeA, * storeB;" << std::endl;
stream << "long4 ids;" << std::endl;
stream << "int2 idT;" << std::endl;
stream << _size_t << " idt;" << std::endl;
if(has_depth)
stream << _size_t << " gidz, div, offz;" << std::endl;
stream << "int Ky, Kx;" << std::endl;
stream << "A += offa;" << std::endl;
stream << "B += offb;" << std::endl;
stream << "C += offc;" << std::endl;
@@ -228,15 +231,19 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "Ky = K - idT.y;" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "Kx = K - idT.x;" << std::endl;
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
@@ -246,8 +253,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
else
@@ -258,8 +265,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -272,8 +279,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
else
@@ -284,21 +291,21 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N')
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
stream << "readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
stream << "readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
if(B_trans_=='T')
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
stream << "readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
stream << "readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
stream << "//Inner loop" << std::endl;