GEMM: Removed offx, offy

This commit is contained in:
Philippe Tillet
2015-07-18 10:24:44 -07:00
parent f4615446c5
commit 7fdb8c0457

View File

@@ -51,8 +51,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
throw operation_not_supported_exception("Only local memory is supported for GEMM");
// if(p_.depth > 1 && M*N*p_.depth > 2e6)
// throw operation_not_supported_exception("This would necessitate a temporary larger than 1MB");
if(p_.depth > 1 && M*N*p_.depth > 2e6)
throw operation_not_supported_exception("This would necessitate a temporary larger than 1MB");
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
@@ -149,22 +149,22 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.inc_tab();
///Declare
stream << "//Declarations" << std::endl;
//Block
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{0}};" << std::endl;
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.simd_width << "];" << std::endl;
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.simd_width << "];" << std::endl;
//Pointers
size_t llda = (A_trans_=='N')?p_.mL:p_.kL;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL;
stream << Local(backend) << " " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
stream << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA, * readB, * storeA, * storeB;" << std::endl;
unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA, * readB, * storeA, * storeB;" << std::endl;
stream << Global(backend) << " " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << Global(backend) << " " << sdtype << "* Bi[" << npB << "];" << std::endl;
//Helpers
stream << "long4 ids;" << std::endl;
stream << "int2 idT;" << std::endl;
stream << _size_t << " idt;" << std::endl;
@@ -172,15 +172,13 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << _size_t << " gidz, div, offz;" << std::endl;
stream << "int Ky, Kx;" << std::endl;
stream << "A += offa;" << std::endl;
stream << "B += offb;" << std::endl;
stream << "C += offc;" << std::endl;
stream << std::endl;
stream << "//Helpers" << std::endl;
stream << "ids.x = " << GroupIdx0(backend) << ";" << std::endl;
stream << "ids.y = " << GroupIdx1(backend) << ";" << std::endl;
stream << "ids.z = " << LocalIdx0(backend) << ";" << std::endl;
stream << "ids.w = " << LocalIdx1(backend) << ";" << std::endl;
if(has_depth)
{
stream << "gidz = " << GroupIdx2(backend) << ";" << std::endl;
@@ -188,54 +186,51 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "offz = div*gidz;" << std::endl;
stream << "K = min(K - div*gidz, div);" << std::endl;
}
stream << "idt = " << p_.local_size_0 << "*ids.w + ids.z;" << std::endl;
stream << "idT.y = idt/" << p_.local_fetch_0 << ";" << std::endl;
stream << "idT.x = idt - " << p_.local_fetch_0 << "*idT.y;" << std::endl;
stream << "ids.x *= " << p_.mL << ";" << std::endl;
stream << "ids.y *= " << p_.nL << ";" << std::endl;
stream << "idT.x *= " << p_.simd_width << ";" << std::endl;
stream << "M -= ids.x;" << std::endl;
stream << "N -= ids.y;" << std::endl;
stream << std::endl;
stream << "// Offset A" << std::endl;
stream << "A += offa;" << std::endl;
if (A_trans_=='N')
stream << "A += (idT.x + ids.x)" << ASTRIDE1 << " + idT.y*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
else
stream << "A += idT.x" << ASTRIDE1 << " + idT.y*lda + ids.x*lda" << (has_depth?"+ offz":"") << ";" << std::endl;
if(B_trans_=='T')
stream << "B += (idT.x + ids.y)" << BSTRIDE1 << " + idT.y*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl;
else
stream << "B += idT.x" << BSTRIDE1 << " + idT.y*ldb + ids.y*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(int i = 0 ; i < " << npA << " ; ++i) " << std::endl;
stream << "Ai[i] = A;" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(int i = 0 ; i < " << npB << " ; ++i)" << std::endl;
stream << "Bi[i] = B;" << std::endl;
for(int i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] = A;" << std::endl;
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl;
else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << std::endl;
stream << "// Offset B" << std::endl;
stream << "B += offb;" << std::endl;
if(B_trans_=='T')
stream << "B += (idT.x + ids.y)" << BSTRIDE1 << " + idT.y*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl;
else
stream << "B += idT.x" << BSTRIDE1 << " + idT.y*ldb + ids.y*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
for(int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] = B;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream << "while(K > 0)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
@@ -309,6 +304,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "//Inner loop" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
stream.inc_tab();
@@ -394,14 +390,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.dec_tab();
stream << "}" << std::endl;
stream << std::endl;
stream << "// Offset C" << std::endl;
stream << "C += offc;" << std::endl;
stream << "C += ids.x" << CSTRIDE1 << ";" << std::endl;
stream << "C += ids.z*" << p_.simd_width << CSTRIDE1 << ";" << std::endl;
stream << "C += ids.y*ldc;" << std::endl;
stream << "C += ids.w*ldc*" << p_.simd_width << ";" << std::endl;
if(has_depth)
stream << "C += gidz*ldc*N;" << std::endl;
stream << std::endl;
stream << "//Write back C" << std::endl;
stream << "M += ids.x;" << std::endl;
stream << "N += ids.y;" << std::endl;
stream << "size_t offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
stream << "size_t offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;
stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
stream << "N -= offy;" << std::endl;
stream << "M -= offx;" << std::endl;
stream << "M -= ids.z*" << p_.simd_width << ";" << std::endl;
stream << "N -= ids.w*" << p_.simd_width << ";" << std::endl;
stream << "int ibm[" << p_.mS << "];" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
{
@@ -412,13 +414,14 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(int_t n=0; n < p_.nS; ++n)
{
string Cj = to_string((n/p_.simd_width)*(p_.local_size_1*p_.simd_width) + n%p_.simd_width);
stream << "if(" << Cj << " >= N) return;" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
stream << "ibm[" << m << "] = ibm[" << m << "] && (" << Cj << " < N);" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
{
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
stream << "if(ibm[" << m << "]) ";
stream << "if(ibm[" << m << "])";
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + select((" << sdtype << ")0, C[" << Ci << CSTRIDE1 << "], beta>0);" << std::endl;
}
if((n+1)%p_.simd_width==0)