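GEMM: Cleaned up the generated GEMM kernel code a little bit: renamed the leading-dimension and offset arguments (Ald/Aoff -> lda/offa, likewise for B and C), emitted the depth terms (offz/gidz) only when depth > 1, and applied alpha to rC before the write-back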

This commit is contained in:
Philippe Tillet
2015-07-14 20:40:29 -07:00
parent 753a9b1f3e
commit 8be02a50c3

View File

@@ -103,7 +103,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
using tools::to_string; using tools::to_string;
driver::backend_type backend = device.backend(); driver::backend_type backend = device.backend();
bool has_depth = p_.depth > 1;
#define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, backend) #define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, backend)
#define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, backend) #define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, backend)
#define ASTRIDE1 string(check_bounds_?"*Astride1":"") #define ASTRIDE1 string(check_bounds_?"*Astride1":"")
@@ -139,28 +139,31 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
} }
stream << KernelPrefix(backend) << " void " << gemm_name << "(" << _size_t << " M, " << _size_t << " N, " << _size_t << " K, " stream << KernelPrefix(backend) << " void " << gemm_name << "(" << _size_t << " M, " << _size_t << " N, " << _size_t << " K, "
<< Global(backend) << " " << sdtype << "* C, " << _size_t << " Cld," << _size_t << " Coff," << _size_t << " Cstride1, " << Global(backend) << " " << sdtype << "* C, " << _size_t << " ldc," << _size_t << " offc," << _size_t << " Cstride1, "
<< sdtype << " alpha," << sdtype << " alpha,"
<< Global(backend) << " " << sdtype << "* A, " << _size_t << " Ald," << _size_t << " Aoff," << _size_t << " Astride1," << Global(backend) << " " << sdtype << "* A, " << _size_t << " lda," << _size_t << " offa," << _size_t << " Astride1,"
<< Global(backend) << " " << sdtype << "* B, " << _size_t << " Bld," << _size_t << " Boff," << _size_t << " Bstride1," << Global(backend) << " " << sdtype << "* B, " << _size_t << " ldb," << _size_t << " offb," << _size_t << " Bstride1,"
<< sdtype << " beta)" << sdtype << " beta)"
<< std::endl; << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "A += Aoff;" << std::endl; stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "];" << std::endl;
stream << "B += Boff;" << std::endl;
stream << "C += Coff;" << std::endl;
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{(" << sdtype << ")0}};" << std::endl;
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.simd_width << "];" << std::endl; stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.simd_width << "];" << std::endl;
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.simd_width << "];" << std::endl; stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.simd_width << "];" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
for(int_t n=0; n < p_.nS; ++n)
stream << "rC[" << m << "][" << n << "] = 0;" << std::endl;
stream << "A += offa;" << std::endl;
stream << "B += offb;" << std::endl;
stream << "C += offc;" << std::endl;
///Result Values ///Result Values
size_t lAld = (A_trans_=='N')?p_.mL:p_.kL; size_t llda = (A_trans_=='N')?p_.mL:p_.kL;
stream << Local(backend) << " " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl; stream << Local(backend) << " " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
size_t lBld = (B_trans_=='T')?p_.nL:p_.kL; size_t lldb = (B_trans_=='T')?p_.nL:p_.kL;
stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl; stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
stream << std::endl; stream << std::endl;
@@ -169,18 +172,17 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "size_t idx = " << LocalIdx0(backend) << ";" << std::endl; stream << "size_t idx = " << LocalIdx0(backend) << ";" << std::endl;
stream << "size_t idy = " << LocalIdx1(backend) << ";" << std::endl; stream << "size_t idy = " << LocalIdx1(backend) << ";" << std::endl;
if(p_.depth > 1){ if(has_depth)
{
stream << "size_t gidz = " << GroupIdx2(backend) << ";" << std::endl; stream << "size_t gidz = " << GroupIdx2(backend) << ";" << std::endl;
stream << "size_t chunk_size = K/" << p_.depth << ";" << std::endl; stream << "size_t chunk_size = K/" << p_.depth << ";" << std::endl;
stream << "size_t offz = chunk_size*gidz;" << std::endl; stream << "size_t offz = chunk_size*gidz;" << std::endl;
} }
else{ else
stream << "size_t gidz = 0;" << std::endl; {
stream << "size_t chunk_size = K;" << std::endl; stream << "size_t chunk_size = K;" << std::endl;
stream << "size_t offz = 0;" << std::endl;
} }
stream << std::endl; stream << std::endl;
stream << "size_t idt = " << p_.local_size_0 << "*idy + idx;" << std::endl; stream << "size_t idt = " << p_.local_size_0 << "*idy + idx;" << std::endl;
stream << "size_t idxT = idt % " << p_.local_fetch_0 << ";" << std::endl; stream << "size_t idxT = idt % " << p_.local_fetch_0 << ";" << std::endl;
@@ -191,14 +193,14 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1); unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
if (A_trans_=='N') if (A_trans_=='N')
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl; stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
else else
stream << "A += idxT*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl; stream << "A += idxT*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*lda" << (has_depth?"+ offz":"") << ";" << std::endl;
if(B_trans_=='T') if(B_trans_=='T')
stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl; stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl;
else else
stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl; stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl; stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
@@ -208,7 +210,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if (A_trans_=='N') if (A_trans_=='N')
stream << "if(gidx*" << p_.mL << " + idxT*" << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl; stream << "if(gidx*" << p_.mL << " + idxT*" << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl;
else else
stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*Ald;" << std::endl; stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl; stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
@@ -219,10 +221,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if (B_trans_=='T') if (B_trans_=='T')
stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl; stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
else else
stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*Bld;" << std::endl; stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idyT*" << lAld << " + idxT*" << p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idyT*" << llda << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lBld << " + idxT*" << p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lldb << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << "//Outer loop" << std::endl; stream << "//Outer loop" << std::endl;
stream << "for(size_t block_k=0; block_k < chunk_size ; block_k+=" << p_.kL << "){" << std::endl; stream << "for(size_t block_k=0; block_k < chunk_size ; block_k+=" << p_.kL << "){" << std::endl;
@@ -232,17 +234,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "//Fetch A to local memory" << std::endl; stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N') if (A_trans_=='N')
{
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1) for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1)
for(int_t m = 0; m < p_.mL; m += p_.local_fetch_0*p_.simd_width) for(int_t m = 0; m < p_.mL; m += p_.local_fetch_0*p_.simd_width)
{ {
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0)); std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*Ald]"); string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
if(check_bounds_) if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0"; to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*lAld+m)) << ";" << std::endl; stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
} }
}
else else
{
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width) for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
for(int_t m = 0; m < p_.mL; m += p_.local_fetch_1) for(int_t m = 0; m < p_.mL; m += p_.local_fetch_1)
{ {
@@ -251,22 +256,26 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"); string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
if(check_bounds_) if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0"; to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*lAld+k)) << ";" << std::endl; stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
} }
}
stream << "//Fetch B to local memory" << std::endl; stream << "//Fetch B to local memory" << std::endl;
if (B_trans_=='T') if (B_trans_=='T')
{
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1) for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1)
for(int_t n = 0; n < p_.nL; n += p_.local_fetch_0*p_.simd_width) for(int_t n = 0; n < p_.nL; n += p_.local_fetch_0*p_.simd_width)
{ {
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0)); std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*Bld]"); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
if(check_bounds_) if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0"; to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lBld+n)) << ";" << std::endl; stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
} }
}
else else
{
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width) for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
for(int_t n = 0; n < p_.nL; n += p_.local_fetch_1) for(int_t n = 0; n < p_.nL; n += p_.local_fetch_1)
{ {
@@ -275,19 +284,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
if(check_bounds_) if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0"; to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lBld+k)) << ";" << std::endl; stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
} }
}
stream << LocalBarrier(backend) << ";" << std::endl; stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N') if(A_trans_=='N')
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << p_.simd_width << ";" << std::endl;
else else
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << lAld*p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << llda*p_.simd_width << ";" << std::endl;
if(B_trans_=='T') if(B_trans_=='T')
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << p_.simd_width << ";" << std::endl;
else else
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << lBld*p_.simd_width << ";" << std::endl; stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << lldb*p_.simd_width << ";" << std::endl;
stream << "//Inner loop" << std::endl; stream << "//Inner loop" << std::endl;
@@ -302,14 +312,14 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
if(A_trans_=='N') if(A_trans_=='N')
stream << "rA[kk][mm] = " << VLOAD("0", "readA + k*" + to_string(lAld) + " + mm*" + to_string(p_.local_size_0*p_.simd_width) + "+ kk*" + to_string(lAld)) << ";" << std::endl; stream << "rA[kk][mm] = " << VLOAD("0", "readA + k*" + to_string(llda) + " + mm*" + to_string(p_.local_size_0*p_.simd_width) + "+ kk*" + to_string(llda)) << ";" << std::endl;
else else
{ {
if(p_.simd_width==1) if(p_.simd_width==1)
stream << "rA[kk][mm] = readA[k + mm*" << p_.local_size_0*lAld << "+ kk" << "];" << std::endl; stream << "rA[kk][mm] = readA[k + mm*" << p_.local_size_0*llda << "+ kk" << "];" << std::endl;
else else
for(unsigned int s = 0 ; s < p_.simd_width ; ++s) for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << access_vector_type("rA[kk][mm]", s) << " = readA[k + (mm*" << p_.simd_width*p_.local_size_0 << " + " << s << ")*" << lAld << "+ kk];" << std::endl; stream << access_vector_type("rA[kk][mm]", s) << " = readA[k + (mm*" << p_.simd_width*p_.local_size_0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
} }
stream.dec_tab(); stream.dec_tab();
@@ -323,14 +333,14 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
if(B_trans_=='T') if(B_trans_=='T')
stream << "rB[kk][nn] = " << VLOAD("0", "readB + k*" + to_string(lBld) + " + nn*" + to_string(p_.local_size_1*p_.simd_width) + "+ kk*" + to_string(lBld)) << ";" << std::endl; stream << "rB[kk][nn] = " << VLOAD("0", "readB + k*" + to_string(lldb) + " + nn*" + to_string(p_.local_size_1*p_.simd_width) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
else else
{ {
if(p_.simd_width==1) if(p_.simd_width==1)
stream << "rB[kk][nn] = readB[k" << " + nn*" << p_.local_size_1*lBld << "+ kk" << "];" << std::endl; stream << "rB[kk][nn] = readB[k" << " + nn*" << p_.local_size_1*lldb << "+ kk" << "];" << std::endl;
else else
for(unsigned int s = 0 ; s < p_.simd_width ; ++s) for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << access_vector_type("rB[kk][nn]", s) << " = readB[k" << " + (nn*" << p_.simd_width*p_.local_size_1 << " + " << s << ")*" << lBld << "+ kk];" << std::endl; stream << access_vector_type("rB[kk][nn]", s) << " = readB[k" << " + (nn*" << p_.simd_width*p_.local_size_1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
} }
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
@@ -359,7 +369,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
//Increment A pointers to global memory //Increment A pointers to global memory
if (A_trans_=='N') if (A_trans_=='N')
for(unsigned int i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << "*Ald;" << std::endl; stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
else else
for(unsigned int i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl; stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
@@ -367,7 +377,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
//Increment B pointers to global memory //Increment B pointers to global memory
if (B_trans_=='T') if (B_trans_=='T')
for(unsigned int i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << "*Bld;" << std::endl; stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
else else
for(unsigned int i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl; stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
@@ -375,13 +385,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
stream << "//Write back C" << std::endl; stream << "//Write back C" << std::endl;
//alpha
for(int_t m=0; m < p_.mS; ++m)
for(int_t n=0; n < p_.nS; ++n)
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
//beta
unsigned int ministartstride0 = p_.simd_width; unsigned int ministartstride0 = p_.simd_width;
unsigned int ministartstride1 = p_.simd_width; unsigned int ministartstride1 = p_.simd_width;
stream << "size_t offx = (gidx*" << p_.mL << " + idx*" << ministartstride0 << ")" << ";" << std::endl; stream << "size_t offx = (gidx*" << p_.mL << " + idx*" << ministartstride0 << ")" << ";" << std::endl;
stream << "size_t offy = (gidy*" << p_.nL << " + idy*" << ministartstride1 << ");" << std::endl; stream << "size_t offy = (gidy*" << p_.nL << " + idy*" << ministartstride1 << ");" << std::endl;
stream << "C += " << "offx" << CSTRIDE1 << " + offy*Cld + gidz*Cld*N;" << std::endl; stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
stream << std::endl;
for(int_t m=0; m < p_.mS; ++m) for(int_t m=0; m < p_.mS; ++m)
for(int_t n=0; n < p_.nS; ++n) for(int_t n=0; n < p_.nS; ++n)
{ {
@@ -391,24 +408,24 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
string Ci = to_string((m/p_.simd_width)*(ministride0*p_.simd_width) + m%p_.simd_width); string Ci = to_string((m/p_.simd_width)*(ministride0*p_.simd_width) + m%p_.simd_width);
string Cj = to_string((n/p_.simd_width)*(ministride1*p_.simd_width) + n%p_.simd_width); string Cj = to_string((n/p_.simd_width)*(ministride1*p_.simd_width) + n%p_.simd_width);
stream << "if((offx + " << Ci << ")<M && (" << Cj << " + offy)<N)"<< std::flush; stream << "if((offx + " << Ci << ")<M && (" << Cj << " + offy)<N)"<< std::flush;
stream << "C[" << Ci << CSTRIDE1 << " + " << Cj << "*Cld] = rC[" << m << "][" << n << "]*alpha + ((beta==0)?0:beta*C[" << Ci << " + " << Cj << "*Cld]);" << std::endl; stream << "C[" << Ci << CSTRIDE1 << " + " << Cj << "*ldc] = rC[" << m << "][" << n << "] + ((beta==0)?0:beta*C[" << Ci << " + " << Cj << "*ldc]);" << std::endl;
} }
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
if(p_.depth > 1) if(has_depth)
{ {
stream << KernelPrefix(backend) << " void " << reduce_name << "(" << _size_t << " M, " << _size_t << " N, " << _size_t << " D, " stream << KernelPrefix(backend) << " void " << reduce_name << "(" << _size_t << " M, " << _size_t << " N, " << _size_t << " D, "
<< Global(backend) << " " << sdtype << "* Z, " << _size_t << " Zld," << Global(backend) << " " << sdtype << "* Z, " << _size_t << " Zld,"
<< Global(backend) << " " << sdtype << "* C, " << _size_t << " Cld," << _size_t << " Cstart1," << _size_t << " Cstart2," << _size_t << " Cstride1, " << _size_t << " Cstride2, " << Global(backend) << " " << sdtype << "* C, " << _size_t << " ldc," << _size_t << " Cstart1," << _size_t << " Cstart2," << _size_t << " Cstride1, " << _size_t << " Cstride2, "
<< sdtype << " beta)" << sdtype << " beta)"
<< std::endl; << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
stream << "C += Cstart1 + Cstart2*Cld;" << std::endl; stream << "C += Cstart1 + Cstart2*ldc;" << std::endl;
stream << "Cld *= Cstride2;" << std::endl; stream << "ldc *= Cstride2;" << std::endl;
stream << "for(unsigned int i = " << GlobalIdx0(backend) << " ; i < M ; i += " << GlobalSize0(backend) << ")" << std::endl; stream << "for(unsigned int i = " << GlobalIdx0(backend) << " ; i < M ; i += " << GlobalSize0(backend) << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
@@ -420,7 +437,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.inc_tab(); stream.inc_tab();
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl; stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
stream.dec_tab(); stream.dec_tab();
stream << "C[i*Cstride1 + j*Cld] = acc + beta*C[i + j*Cld];" << std::endl; stream << "C[i*Cstride1 + j*ldc] = acc + beta*C[i + j*ldc];" << std::endl;
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
stream.dec_tab(); stream.dec_tab();