GEMM: Fixing bounds checking on K

This commit is contained in:
Philippe Tillet
2015-07-21 14:35:22 -04:00
parent 18663d6a93
commit 33bd3a77fc

View File

@@ -204,7 +204,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if (A_trans_=='N') if (A_trans_=='N')
{ {
stream << "A += ids.x" << ASTRIDE1 << ";" << std::endl; stream << "A += ids.x" << ASTRIDE1 << ";" << std::endl;
stream << "if(idT.y < K) A += idT.y*lda;" << std::endl; stream << "A += idT.y*lda;" << std::endl;
if(has_depth) if(has_depth)
stream << "A += offz*lda;" << std::endl; stream << "A += offz*lda;" << std::endl;
@@ -212,7 +212,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
else else
{ {
stream << "A += ids.x*lda;" << std::endl; stream << "A += ids.x*lda;" << std::endl;
stream << "if(idT.x < K) A += idT.x" << ASTRIDE1 << ";" << std::endl; stream << "A += idT.x" << ASTRIDE1 << ";" << std::endl;
if(has_depth) if(has_depth)
stream << "A += offz;" << std::endl; stream << "A += offz;" << std::endl;
} }
@@ -220,37 +220,41 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if(B_trans_=='T') if(B_trans_=='T')
{ {
stream << "B += ids.y" << BSTRIDE1 << ";" << std::endl; stream << "B += ids.y" << BSTRIDE1 << ";" << std::endl;
stream << "if(idT.y < K) B += idT.y*ldb;" << std::endl; stream << "B += idT.y*ldb;" << std::endl;
if(has_depth) if(has_depth)
stream << "B += offz*ldb;" << std::endl; stream << "B += offz*ldb;" << std::endl;
} }
else else
{ {
stream << "B += ids.y*ldb;" << std::endl; stream << "B += ids.y*ldb;" << std::endl;
stream << "if(idT.x < K) B += idT.x" << BSTRIDE1 << ";" << std::endl; stream << "B += idT.x" << BSTRIDE1 << ";" << std::endl;
if(has_depth) if(has_depth)
stream << "B += offz;" << std::endl; stream << "B += offz;" << std::endl;
} }
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
stream << "for(int i = 0 ; i < " << npA << " ; ++i) " << std::endl; stream << "for(int i = 0 ; i < " << npA << " ; ++i) " << std::endl;
stream.inc_tab();
stream << "Ai[i] = A;" << std::endl; stream << "Ai[i] = A;" << std::endl;
stream.dec_tab();
stream << "#pragma unroll" << std::endl; stream << "#pragma unroll" << std::endl;
stream << "for(int i = 0 ; i < " << npB << " ; ++i)" << std::endl; stream << "for(int i = 0 ; i < " << npB << " ; ++i)" << std::endl;
stream.inc_tab();
stream << "Bi[i] = B;" << std::endl; stream << "Bi[i] = B;" << std::endl;
stream.dec_tab();
for(unsigned int i = 0 ; i < npA ; i++ ) for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N') if (A_trans_=='N')
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl; stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl;
else else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl; stream << "if(idT.y + " << i*p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ ) for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T') if (B_trans_=='T')
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl; stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl;
else else
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl; stream << "if(idT.y + " << i*p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl; stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl; stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
@@ -265,6 +269,25 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if(A_trans_=='T' || B_trans_=='N') if(A_trans_=='T' || B_trans_=='N')
stream << "Kx = K - idT.x;" << std::endl; stream << "Kx = K - idT.x;" << std::endl;
std::string vint = append_width("int", p_.simd_width);
if(A_trans_=='N' || B_trans_=='T')
{
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
stream << vint << " condy" << k << " = (" << vint << ")(" << k << ") < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N')
{
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
{
stream << vint << " condx" << k << " = (" << vint << ")(";
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << (s>0?",":"") << k + s;
stream << ") < Kx;" << std::endl;
}
}
stream << "//Fetch A to local memory" << std::endl; stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N') if (A_trans_=='N')
{ {
@@ -274,7 +297,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0)); std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]"); string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
to_load = "(" + kk + "< Ky)?" + to_load + ":0"; to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl; stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
} }
} }
@@ -286,7 +309,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1); std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"); string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
to_load = "(" + kk + "< Kx)?" + to_load + ":0"; to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl; stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
} }
} }
@@ -300,7 +323,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0)); std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]"); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
to_load = "(" + kk + "< Ky)?" + to_load + ":0"; to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl; stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
} }
} }
@@ -312,7 +335,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1); std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k); std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
to_load = "(" + kk + "< Kx)?" + to_load + ":0"; to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl; stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
} }
} }
@@ -419,11 +442,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "//Write back C" << std::endl; stream << "//Write back C" << std::endl;
stream << "M += ids.x;" << std::endl; stream << "M += ids.x;" << std::endl;
stream << "N += ids.y;" << std::endl; stream << "N += ids.y;" << std::endl;
stream << "size_t offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl; stream << _size_t << " offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
stream << "size_t offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl; stream << _size_t << " offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;
stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl; stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
stream << "N -= offy;" << std::endl;
stream << "M -= offx;" << std::endl; stream << "M -= offx;" << std::endl;
stream << "N -= offy;" << std::endl;
stream << "int ibm[" << p_.mS << "];" << std::endl; stream << "int ibm[" << p_.mS << "];" << std::endl;
for(unsigned int m=0; m < p_.mS; ++m) for(unsigned int m=0; m < p_.mS; ++m)
{ {
@@ -488,7 +511,9 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "}" << std::endl; stream << "}" << std::endl;
} }
// std::cout << stream.str() << std::endl;
// if(p_.simd_width>1)
// std::cout << stream.str() << std::endl;
return stream.str(); return stream.str();
#undef VLOAD #undef VLOAD
@@ -612,6 +637,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr) void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
{ {
using namespace tools; using namespace tools;
// std::cout << p_.simd_width << " " << p_.mL << " " << p_.kL << " " << p_.mS << " " << p_.depth << " " << p_.local_size_0 << std::endl;
gemm & fallback = (gemm&)fallback_base; gemm & fallback = (gemm&)fallback_base;
expressions_tuple const & expressions = ctr.x(); expressions_tuple const & expressions = ctr.x();