GEMM: More optimizations
This commit is contained in:
@@ -318,27 +318,28 @@ void bench(ad::numeric_type dtype, std::string operation)
|
|||||||
if(operation.substr(0,4)=="gemm")
|
if(operation.substr(0,4)=="gemm")
|
||||||
{
|
{
|
||||||
std::vector<std::tuple<char, char, int_t, int_t, int_t> > MNKs;
|
std::vector<std::tuple<char, char, int_t, int_t, int_t> > MNKs;
|
||||||
|
MNKs.push_back(std::make_tuple('N','T',1536,1536,1536));
|
||||||
//AlexNet (Forward)
|
//AlexNet (Forward)
|
||||||
MNKs.push_back(std::make_tuple('N','N',3025,96,363));
|
MNKs.push_back(std::make_tuple('N','N',3025,96,363));
|
||||||
MNKs.push_back(std::make_tuple('N','N',729,128,1200));
|
MNKs.push_back(std::make_tuple('N','N',729,128,1200));
|
||||||
// MNKs.push_back(std::make_tuple('N','N',169,384,2304));
|
MNKs.push_back(std::make_tuple('N','N',169,384,2304));
|
||||||
// MNKs.push_back(std::make_tuple('N','N',169,192,1728));
|
MNKs.push_back(std::make_tuple('N','N',169,192,1728));
|
||||||
// MNKs.push_back(std::make_tuple('N','N',169,128,1728));
|
MNKs.push_back(std::make_tuple('N','N',169,128,1728));
|
||||||
// //AlexNet (Backward)
|
//AlexNet (Backward)
|
||||||
// MNKs.push_back(std::make_tuple('T','N',1728,128,169));
|
MNKs.push_back(std::make_tuple('T','N',1728,128,169));
|
||||||
// MNKs.push_back(std::make_tuple('T','N',1728,192,169));
|
MNKs.push_back(std::make_tuple('T','N',1728,192,169));
|
||||||
// MNKs.push_back(std::make_tuple('T','N',2304,384,169));
|
MNKs.push_back(std::make_tuple('T','N',2304,384,169));
|
||||||
// MNKs.push_back(std::make_tuple('T','N',1200,128,729));
|
MNKs.push_back(std::make_tuple('T','N',1200,128,729));
|
||||||
// MNKs.push_back(std::make_tuple('T','N',363,96,3025));
|
MNKs.push_back(std::make_tuple('T','N',363,96,3025));
|
||||||
|
|
||||||
// MNKs.push_back(std::make_tuple('N','T',169,1728,128));
|
MNKs.push_back(std::make_tuple('N','T',169,1728,128));
|
||||||
// MNKs.push_back(std::make_tuple('N','T',169,1728,192));
|
MNKs.push_back(std::make_tuple('N','T',169,1728,192));
|
||||||
// MNKs.push_back(std::make_tuple('N','T',169,2304,384));
|
MNKs.push_back(std::make_tuple('N','T',169,2304,384));
|
||||||
// MNKs.push_back(std::make_tuple('N','T',729,1200,128));
|
MNKs.push_back(std::make_tuple('N','T',729,1200,128));
|
||||||
|
|
||||||
// //Covariance (e.g., ICA)
|
//Covariance (e.g., ICA)
|
||||||
// MNKs.push_back(std::make_tuple('N','N',64,64,32000));
|
MNKs.push_back(std::make_tuple('N','N',64,64,32000));
|
||||||
// MNKs.push_back(std::make_tuple('N','N',1024,1024,32000));
|
MNKs.push_back(std::make_tuple('N','N',1024,1024,32000));
|
||||||
|
|
||||||
/*---------*/
|
/*---------*/
|
||||||
/*--BLAS3--*/
|
/*--BLAS3--*/
|
||||||
|
@@ -163,12 +163,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
||||||
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||||
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
||||||
|
stream << LocalPtr(backend) << " " << sdtype << "* readA, * readB, * storeA, * storeB;" << std::endl;
|
||||||
|
|
||||||
stream << "long4 ids;" << std::endl;
|
stream << "long4 ids;" << std::endl;
|
||||||
stream << "int2 idT;" << std::endl;
|
stream << "int2 idT;" << std::endl;
|
||||||
stream << _size_t << " idt;" << std::endl;
|
stream << _size_t << " idt;" << std::endl;
|
||||||
if(has_depth)
|
if(has_depth)
|
||||||
stream << _size_t << " gidz, div, offz;" << std::endl;
|
stream << _size_t << " gidz, div, offz;" << std::endl;
|
||||||
|
stream << "int Ky, Kx;" << std::endl;
|
||||||
|
|
||||||
stream << "A += offa;" << std::endl;
|
stream << "A += offa;" << std::endl;
|
||||||
stream << "B += offb;" << std::endl;
|
stream << "B += offb;" << std::endl;
|
||||||
stream << "C += offc;" << std::endl;
|
stream << "C += offc;" << std::endl;
|
||||||
@@ -228,15 +231,19 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
else
|
else
|
||||||
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
|
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
|
||||||
|
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
stream << "//Outer loop" << std::endl;
|
stream << "//Outer loop" << std::endl;
|
||||||
stream << "while(K > 0){" << std::endl;
|
stream << "while(K > 0){" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||||
|
|
||||||
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
|
stream << "Ky = K - idT.y;" << std::endl;
|
||||||
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
|
stream << "Kx = K - idT.x;" << std::endl;
|
||||||
|
|
||||||
stream << "//Fetch A to local memory" << std::endl;
|
stream << "//Fetch A to local memory" << std::endl;
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
{
|
{
|
||||||
@@ -246,8 +253,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
|
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
|
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
|
||||||
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -258,8 +265,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string mm = to_string(m/p_.local_fetch_1);
|
std::string mm = to_string(m/p_.local_fetch_1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
||||||
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,8 +279,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
|
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
|
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
|
||||||
to_load = "(idT.y + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -284,21 +291,21 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string nn = to_string(n/p_.local_fetch_1);
|
std::string nn = to_string(n/p_.local_fetch_1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
||||||
to_load = "(idT.x + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||||
|
|
||||||
if(A_trans_=='N')
|
if(A_trans_=='N')
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
stream << "readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
||||||
else
|
else
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
|
stream << "readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
|
||||||
|
|
||||||
if(B_trans_=='T')
|
if(B_trans_=='T')
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
|
stream << "readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
|
||||||
else
|
else
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
|
stream << "readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
stream << "//Inner loop" << std::endl;
|
stream << "//Inner loop" << std::endl;
|
||||||
|
Reference in New Issue
Block a user