GEMM: More bugfixes
This commit is contained in:
@@ -49,7 +49,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||||
throw operation_not_supported_exception("Only local memory is supported for GEMM");
|
throw operation_not_supported_exception("Only local memory is supported for GEMM");
|
||||||
|
|
||||||
if(p_.depth > 1 && M*N*p_.depth > 1e6)
|
if(p_.depth > 1 && M*N*p_.depth > 2e6)
|
||||||
throw operation_not_supported_exception("This would necessitate a temporary larger than 1MB");
|
throw operation_not_supported_exception("This would necessitate a temporary larger than 1MB");
|
||||||
|
|
||||||
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
|
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
|
||||||
@@ -192,12 +192,12 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
|
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
|
||||||
else
|
else
|
||||||
stream << "A += idxT" << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl;
|
stream << "A += (idxT)*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl;
|
||||||
|
|
||||||
if(B_trans_=='T')
|
if(B_trans_=='T')
|
||||||
stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
|
stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
|
||||||
else
|
else
|
||||||
stream << "B += idxT" << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl;
|
stream << "B += (idxT)*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||||
@@ -249,7 +249,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*Ald]");
|
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*Ald]");
|
||||||
if(check_bounds_)
|
if(check_bounds_)
|
||||||
to_load = "(block_k + idyT + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*lAld+m)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*lAld+m)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -258,16 +258,16 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
{
|
{
|
||||||
std::string mm = to_string(k/p_.local_fetch_1);
|
std::string mm = to_string(k/p_.local_fetch_1);
|
||||||
std::string kk = to_string(m/p_.simd_width);
|
std::string kk = to_string(m/p_.simd_width);
|
||||||
string to_load = "Ai[" + mm + "][" + kk + ASTRIDE1 + "]";
|
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
||||||
if(check_bounds_)
|
if(check_bounds_)
|
||||||
to_load = "(block_k + idxT + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||||
if(p_.simd_width==1)
|
if(p_.simd_width==1)
|
||||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*lAld+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*lAld+k)) << ";" << std::endl;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
stream << vdtype << " tmpA" << k << m << " = " << to_load << ";" << std::endl;
|
stream << vdtype << " tmpA" << k << "_" << m << " = " << to_load << ";" << std::endl;
|
||||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||||
stream << "lAstore[" << k + (m + s)*lAld << "]= tmpA" << k << m << ".s" << s << ";" << std::endl;
|
stream << "lAstore[" << k + (m + s)*lAld << "]= tmpA" << k << "_" << m << ".s" << s << ";" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -281,7 +281,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
std::string kk = to_string(k);
|
std::string kk = to_string(k);
|
||||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*Bld]");
|
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*Bld]");
|
||||||
if(check_bounds_)
|
if(check_bounds_)
|
||||||
to_load = "(block_k + idyT + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lBld+n)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lBld+n)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -289,17 +289,17 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width)
|
for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width)
|
||||||
{
|
{
|
||||||
std::string nn = to_string(k/p_.local_fetch_1);
|
std::string nn = to_string(k/p_.local_fetch_1);
|
||||||
std::string kk = to_string(n/p_.simd_width);
|
std::string kk = to_string(n);
|
||||||
string to_load = "Bi[" + nn + "][" + kk + BSTRIDE1 + "]";
|
string to_load = VLOAD("0","&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
||||||
if(check_bounds_)
|
if(check_bounds_)
|
||||||
to_load = "(block_k + idxT + " + kk + "< K)?" + to_load + ":0";
|
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||||
if(p_.simd_width==1)
|
if(p_.simd_width==1)
|
||||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lBld+k)) << ";" << std::endl;
|
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lBld+k)) << ";" << std::endl;
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
stream << vdtype << " tmpB" << k << n << " = " << to_load << ";" << std::endl;
|
stream << vdtype << " tmpB" << k << "_" << n << " = " << to_load << ";" << std::endl;
|
||||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||||
stream << "lBstore[" << k + (n + s)*lBld << "]= tmpB" << k << n << ".s" << s << ";" << std::endl;
|
stream << "lBstore[" << k + (n + s)*lBld << "]= tmpB" << k << "_" << n << ".s" << s << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||||
@@ -361,7 +361,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
stream << "Ai[" << i << "] += " << p_.kL << "*Ald;" << std::endl;
|
stream << "Ai[" << i << "] += " << p_.kL << "*Ald;" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int i = 0 ; i < npA ; ++i)
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
||||||
stream << "Ai[" << i << "] += " << p_.kL/p_.simd_width << ASTRIDE1 << ";" << std::endl;
|
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
||||||
|
|
||||||
//Increment B pointers to global memory
|
//Increment B pointers to global memory
|
||||||
if (B_trans_=='T')
|
if (B_trans_=='T')
|
||||||
@@ -369,7 +369,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
stream << "Bi[" << i << "] += " << p_.kL << "*Bld;" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << "*Bld;" << std::endl;
|
||||||
else
|
else
|
||||||
for(unsigned int i = 0 ; i < npB ; ++i)
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
||||||
stream << "Bi[" << i << "] += " << p_.kL/p_.simd_width << BSTRIDE1 << ";" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
@@ -429,8 +429,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(p_.simd_width>1)
|
// if(p_.simd_width>1)
|
||||||
std::cout << stream.str() << std::endl;
|
// std::cout << stream.str() << std::endl;
|
||||||
return stream.str();
|
return stream.str();
|
||||||
|
|
||||||
#undef VLOAD
|
#undef VLOAD
|
||||||
@@ -605,6 +605,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
// std::cout << p_.local_size_0 << " " << p_.kL << " " << p_.local_size_1 << " " << p_.depth << std::endl;
|
||||||
value_scalar _1(1, dtype);
|
value_scalar _1(1, dtype);
|
||||||
enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
|
enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
|
||||||
fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
|
fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
|
||||||
|
Reference in New Issue
Block a user