GEMM: Fixing bounds checking on K
This commit is contained in:
@@ -204,7 +204,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
{
|
{
|
||||||
stream << "A += ids.x" << ASTRIDE1 << ";" << std::endl;
|
stream << "A += ids.x" << ASTRIDE1 << ";" << std::endl;
|
||||||
stream << "if(idT.y < K) A += idT.y*lda;" << std::endl;
|
stream << "A += idT.y*lda;" << std::endl;
|
||||||
if(has_depth)
|
if(has_depth)
|
||||||
stream << "A += offz*lda;" << std::endl;
|
stream << "A += offz*lda;" << std::endl;
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
stream << "A += ids.x*lda;" << std::endl;
|
stream << "A += ids.x*lda;" << std::endl;
|
||||||
stream << "if(idT.x < K) A += idT.x" << ASTRIDE1 << ";" << std::endl;
|
stream << "A += idT.x" << ASTRIDE1 << ";" << std::endl;
|
||||||
if(has_depth)
|
if(has_depth)
|
||||||
stream << "A += offz;" << std::endl;
|
stream << "A += offz;" << std::endl;
|
||||||
}
|
}
|
||||||
@@ -220,37 +220,41 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
if(B_trans_=='T')
|
if(B_trans_=='T')
|
||||||
{
|
{
|
||||||
stream << "B += ids.y" << BSTRIDE1 << ";" << std::endl;
|
stream << "B += ids.y" << BSTRIDE1 << ";" << std::endl;
|
||||||
stream << "if(idT.y < K) B += idT.y*ldb;" << std::endl;
|
stream << "B += idT.y*ldb;" << std::endl;
|
||||||
if(has_depth)
|
if(has_depth)
|
||||||
stream << "B += offz*ldb;" << std::endl;
|
stream << "B += offz*ldb;" << std::endl;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
stream << "B += ids.y*ldb;" << std::endl;
|
stream << "B += ids.y*ldb;" << std::endl;
|
||||||
stream << "if(idT.x < K) B += idT.x" << BSTRIDE1 << ";" << std::endl;
|
stream << "B += idT.x" << BSTRIDE1 << ";" << std::endl;
|
||||||
if(has_depth)
|
if(has_depth)
|
||||||
stream << "B += offz;" << std::endl;
|
stream << "B += offz;" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
stream << "for(int i = 0 ; i < " << npA << " ; ++i) " << std::endl;
|
stream << "for(int i = 0 ; i < " << npA << " ; ++i) " << std::endl;
|
||||||
|
stream.inc_tab();
|
||||||
stream << "Ai[i] = A;" << std::endl;
|
stream << "Ai[i] = A;" << std::endl;
|
||||||
|
stream.dec_tab();
|
||||||
|
|
||||||
stream << "#pragma unroll" << std::endl;
|
stream << "#pragma unroll" << std::endl;
|
||||||
stream << "for(int i = 0 ; i < " << npB << " ; ++i)" << std::endl;
|
stream << "for(int i = 0 ; i < " << npB << " ; ++i)" << std::endl;
|
||||||
|
stream.inc_tab();
|
||||||
stream << "Bi[i] = B;" << std::endl;
|
stream << "Bi[i] = B;" << std::endl;
|
||||||
|
stream.dec_tab();
|
||||||
|
|
||||||
for(unsigned int i = 0 ; i < npA ; i++ )
|
for(unsigned int i = 0 ; i < npA ; i++ )
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl;
|
stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl;
|
||||||
else
|
else
|
||||||
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl;
|
stream << "if(idT.y + " << i*p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl;
|
||||||
|
|
||||||
for(unsigned int i = 0 ; i < npB ; i++ )
|
for(unsigned int i = 0 ; i < npB ; i++ )
|
||||||
if (B_trans_=='T')
|
if (B_trans_=='T')
|
||||||
stream << "if(idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl;
|
stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl;
|
||||||
else
|
else
|
||||||
stream << "if(idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl;
|
stream << "if(idT.y + " << i*p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl;
|
||||||
|
|
||||||
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||||
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||||
@@ -265,6 +269,25 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
if(A_trans_=='T' || B_trans_=='N')
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
stream << "Kx = K - idT.x;" << std::endl;
|
stream << "Kx = K - idT.x;" << std::endl;
|
||||||
|
|
||||||
|
std::string vint = append_width("int", p_.simd_width);
|
||||||
|
|
||||||
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
|
{
|
||||||
|
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||||
|
stream << vint << " condy" << k << " = (" << vint << ")(" << k << ") < Ky;" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
|
{
|
||||||
|
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
||||||
|
{
|
||||||
|
stream << vint << " condx" << k << " = (" << vint << ")(";
|
||||||
|
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||||
|
stream << (s>0?",":"") << k + s;
|
||||||
|
stream << ") < Kx;" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
stream << "//Fetch A to local memory" << std::endl;
|
stream << "//Fetch A to local memory" << std::endl;
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
{
|
{
|
||||||
@@ -274,7 +297,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
|
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
|
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
|
||||||
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
|
to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
|
||||||
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -286,7 +309,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string mm = to_string(m/p_.local_fetch_1);
|
std::string mm = to_string(m/p_.local_fetch_1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
||||||
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
|
to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
|
||||||
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +323,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
|
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
|
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
|
||||||
to_load = "(" + kk + "< Ky)?" + to_load + ":0";
|
to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
|
||||||
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -312,7 +335,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
std::string nn = to_string(n/p_.local_fetch_1);
|
std::string nn = to_string(n/p_.local_fetch_1);
|
||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
||||||
to_load = "(" + kk + "< Kx)?" + to_load + ":0";
|
to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
|
||||||
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -419,11 +442,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream << "//Write back C" << std::endl;
|
stream << "//Write back C" << std::endl;
|
||||||
stream << "M += ids.x;" << std::endl;
|
stream << "M += ids.x;" << std::endl;
|
||||||
stream << "N += ids.y;" << std::endl;
|
stream << "N += ids.y;" << std::endl;
|
||||||
stream << "size_t offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
|
stream << _size_t << " offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
|
||||||
stream << "size_t offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;
|
stream << _size_t << " offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;
|
||||||
stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
|
stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
|
||||||
stream << "N -= offy;" << std::endl;
|
|
||||||
stream << "M -= offx;" << std::endl;
|
stream << "M -= offx;" << std::endl;
|
||||||
|
stream << "N -= offy;" << std::endl;
|
||||||
stream << "int ibm[" << p_.mS << "];" << std::endl;
|
stream << "int ibm[" << p_.mS << "];" << std::endl;
|
||||||
for(unsigned int m=0; m < p_.mS; ++m)
|
for(unsigned int m=0; m < p_.mS; ++m)
|
||||||
{
|
{
|
||||||
@@ -488,7 +511,9 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
// std::cout << stream.str() << std::endl;
|
|
||||||
|
// if(p_.simd_width>1)
|
||||||
|
// std::cout << stream.str() << std::endl;
|
||||||
return stream.str();
|
return stream.str();
|
||||||
|
|
||||||
#undef VLOAD
|
#undef VLOAD
|
||||||
@@ -612,6 +637,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||||
{
|
{
|
||||||
using namespace tools;
|
using namespace tools;
|
||||||
|
// std::cout << p_.simd_width << " " << p_.mL << " " << p_.kL << " " << p_.mS << " " << p_.depth << " " << p_.local_size_0 << std::endl;
|
||||||
|
|
||||||
gemm & fallback = (gemm&)fallback_base;
|
gemm & fallback = (gemm&)fallback_base;
|
||||||
expressions_tuple const & expressions = ctr.x();
|
expressions_tuple const & expressions = ctr.x();
|
||||||
|
Reference in New Issue
Block a user