|
|
|
@@ -200,77 +200,30 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
}
|
|
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
|
|
if (check_bounds_)
|
|
|
|
|
{
|
|
|
|
|
//Bounds checking for M (in A, C)
|
|
|
|
|
stream << "bool in_bounds_m[" << p_.mS << "] = {" ;
|
|
|
|
|
for(unsigned int m = 0; m < p_.mS ; m++)
|
|
|
|
|
{
|
|
|
|
|
if(m > 0) stream << ",";
|
|
|
|
|
switch(p_.A_fetching_policy)
|
|
|
|
|
{
|
|
|
|
|
case FETCH_FROM_GLOBAL_CONTIGUOUS:
|
|
|
|
|
stream << "gidx*" << p_.mL << " + idx*" << p_.mS << " + " << m << "< M";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
stream << "gidx*" << p_.mL << " + idx + " << m * p_.local_size_0 << " < M";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
stream << "};" << std::endl;
|
|
|
|
|
unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
|
|
|
|
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
|
|
|
|
if (A_trans_=='N')
|
|
|
|
|
stream << "__global " << vdtype << "* Ai[" << npA << "] = {A + (gidx*" << p_.mL/p_.simd_width << ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald};" << std::endl;
|
|
|
|
|
else
|
|
|
|
|
stream << "__global " << vdtype << "* Ai[" << npA << "] = {A + idxT" << ASTRIDE1 << " + gidx*" << p_.mL/p_.simd_width << "*Ald + offz};" << std::endl;
|
|
|
|
|
|
|
|
|
|
//Bounds checking for N (in B, C)
|
|
|
|
|
stream << "bool in_bounds_n[" << p_.nS << "] = {";
|
|
|
|
|
for(unsigned int n = 0; n < p_.nS ; n++)
|
|
|
|
|
{
|
|
|
|
|
if(n > 0) stream << ",";
|
|
|
|
|
switch (p_.B_fetching_policy)
|
|
|
|
|
{
|
|
|
|
|
case FETCH_FROM_GLOBAL_CONTIGUOUS:
|
|
|
|
|
stream << "gidy*" << p_.nL << " + idy*" << p_.nS << " + " << n << " < N";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
stream << "gidy*" << p_.nL << " + idy + " << n * p_.local_size_1 << " < N";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
stream << "};" << std::endl;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//Bounds checking for A if Local
|
|
|
|
|
if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
|
|
|
|
|
{
|
|
|
|
|
unsigned int fetch_size = (A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
|
|
|
|
stream << "bool in_bounds_A[" << p_.mL/fetch_size << "];" << std::endl;
|
|
|
|
|
stream << "for(unsigned int m = 0; m < " << p_.mL/fetch_size << "; m++)" << std::endl;
|
|
|
|
|
stream.inc_tab();
|
|
|
|
|
stream << "in_bounds_A[m] = (gidx*" << p_.mL << " + " << (A_trans_=='N'?"idxT":"idyT") << " + m*" << fetch_size << ") < M;" << std::endl;
|
|
|
|
|
stream.dec_tab();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//Bounds checking for B if Local
|
|
|
|
|
if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
|
|
|
|
|
{
|
|
|
|
|
unsigned int fetch_size = (B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
|
|
|
|
|
stream << "bool in_bounds_B[" << p_.nL/fetch_size << "];" << std::endl;
|
|
|
|
|
stream << "for(unsigned int n = 0; n < " << p_.nL/fetch_size << "; n++)" << std::endl;
|
|
|
|
|
stream.inc_tab();
|
|
|
|
|
stream << "in_bounds_B[n] = (gidy*" << p_.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + n*" << fetch_size << ") < N;" << std::endl;
|
|
|
|
|
stream.dec_tab();
|
|
|
|
|
|
|
|
|
|
// for(unsigned int n = 0 ; n < p_.nL/fetch_size ; n++)
|
|
|
|
|
// stream << n>0?",":"" << "(gidy*" << p_.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + " << n*fetch_size << ") < N";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if(B_trans_=='T')
|
|
|
|
|
stream << "__global " << vdtype << "* Bi[" << npB << "] = {B};" << std::endl;
|
|
|
|
|
|
|
|
|
|
switch (p_.A_fetching_policy)
|
|
|
|
|
{
|
|
|
|
|
case FETCH_FROM_LOCAL:
|
|
|
|
|
for(unsigned int i = 0 ; i < npA ; i++ )
|
|
|
|
|
if (A_trans_=='N')
|
|
|
|
|
stream << "A += (gidx*" << p_.mL/p_.simd_width << " + idxT) " << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
|
|
|
|
|
{
|
|
|
|
|
stream << "Ai[" << i << "] += (gidx*" << p_.mL/p_.simd_width << ") " << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
|
|
|
|
|
stream << "if(gidx*" << p_.mL << " + idxT + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idxT + " << i*p_.local_fetch_0 << ")" << ASTRIDE1 << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
stream << "A += idxT" << ASTRIDE1 << " + gidx*" << p_.mL/p_.simd_width << "*Ald + idyT*Ald + offz;" << std::endl;
|
|
|
|
|
{
|
|
|
|
|
stream << "Ai[" << i << "] += idxT" << ASTRIDE1 << " + gidx*" << p_.mL/p_.simd_width << "*Ald + offz;" << std::endl;
|
|
|
|
|
stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += (idyT + " << i*p_.local_fetch_1 << ")*Ald;" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case FETCH_FROM_GLOBAL_CONTIGUOUS:
|
|
|
|
@@ -293,10 +246,18 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
switch (p_.B_fetching_policy)
|
|
|
|
|
{
|
|
|
|
|
case FETCH_FROM_LOCAL:
|
|
|
|
|
for(unsigned int i = 0 ; i < npB ; i++ )
|
|
|
|
|
if (B_trans_=='T')
|
|
|
|
|
stream << "B += (gidy*" << p_.nL/p_.simd_width << " + idxT)" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
|
|
|
|
|
{
|
|
|
|
|
stream << "Bi[" << i << "] += (gidy*" << p_.nL/p_.simd_width << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
|
|
|
|
|
stream << "if(gidy*" << p_.nL << " + idxT + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idxT + " << i*p_.local_fetch_0 << ")" << BSTRIDE1 << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
stream << "B += idxT" << BSTRIDE1 << " + gidy*" << p_.nL/p_.simd_width << "*Bld + idyT*Bld + offz;" << std::endl;
|
|
|
|
|
{
|
|
|
|
|
stream << "Bi[" << i << "] += idxT" << BSTRIDE1 << " + gidy*" << p_.nL/p_.simd_width << "*Bld + offz;" << std::endl;
|
|
|
|
|
stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += (idyT + " << i*p_.local_fetch_1 << ")*Bld;" << std::endl;
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
case FETCH_FROM_GLOBAL_CONTIGUOUS:
|
|
|
|
@@ -341,17 +302,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1)
|
|
|
|
|
for(int_t m = 0; m < p_.mL; m += p_.local_fetch_0*p_.simd_width)
|
|
|
|
|
{
|
|
|
|
|
string in_bounds = "in_bounds_A[" + to_string(m/(p_.local_fetch_0*p_.simd_width)) + "] && (idyT + block_k < K)";
|
|
|
|
|
string to_load = "A[" + to_string(k) + "*Ald + " + to_string(m/p_.simd_width) + ASTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lAstore + lAstart + " + to_string(k*lAld+m)) << ";" << std::endl;
|
|
|
|
|
string to_load = "Ai[" + to_string(m/(p_.simd_width*p_.local_fetch_0)) +"][" + to_string(k) + "*Ald]";
|
|
|
|
|
stream << VSTORE(to_load, "0", "lAstore + lAstart + " + to_string(k*lAld+m)) << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
else if (p_.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='T')
|
|
|
|
|
for(int_t k = 0; k < p_.mL; k += p_.local_fetch_1)
|
|
|
|
|
for(int_t m = 0; m < p_.kL; m += p_.local_fetch_0*p_.simd_width)
|
|
|
|
|
{
|
|
|
|
|
string in_bounds = "in_bounds_A[" + to_string(k/p_.local_fetch_1) + "] && (idxT + block_k < K)";
|
|
|
|
|
string to_load = "A[" + to_string(k) + "*Ald + " + to_string(m/p_.simd_width) + ASTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lAstore + lAstart + " + to_string(m*lAld+k)) << ";" << std::endl;
|
|
|
|
|
string to_load = "Ai[" + to_string(k) + "][" + to_string(m/p_.simd_width) + ASTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(to_load, "0", "lAstore + lAstart + " + to_string(m*lAld+k)) << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
stream << "//Fetch B to local memory" << std::endl;
|
|
|
|
@@ -359,17 +318,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1)
|
|
|
|
|
for(int_t n = 0; n < p_.nL; n += p_.local_fetch_0*p_.simd_width)
|
|
|
|
|
{
|
|
|
|
|
string in_bounds = "in_bounds_B[" + to_string(n/(p_.local_fetch_0*p_.simd_width)) + "] && (idyT + block_k < K)";
|
|
|
|
|
string to_load = "B[" + to_string(k) + "*Bld + " + to_string(n/p_.simd_width) + BSTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lBstore + lBstart + " + to_string(k*lBld+n)) << ";" << std::endl;
|
|
|
|
|
string to_load = "Bi[" + to_string(n/(p_.local_fetch_0*p_.simd_width)) + "][" + to_string(k) + "*Bld]";
|
|
|
|
|
stream << VSTORE(to_load, "0", "lBstore + lBstart + " + to_string(k*lBld+n)) << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
else if (p_.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='N')
|
|
|
|
|
for(int_t k = 0; k < p_.nL; k += p_.local_fetch_1)
|
|
|
|
|
for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width)
|
|
|
|
|
{
|
|
|
|
|
string in_bounds = "in_bounds_B[" + to_string(k/p_.local_fetch_1) + "] && (idxT + block_k < K)";
|
|
|
|
|
string to_load = "B[" + to_string(k) + "*Bld + " + to_string(n/p_.simd_width) + BSTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lBstore + lBstart + " + to_string(n*lBld+k)) << ";" << std::endl;
|
|
|
|
|
string to_load = "Bi[" + to_string(k) + "][" + to_string(n/p_.simd_width) + BSTRIDE1 + "]";
|
|
|
|
|
stream << VSTORE(to_load, "0", "lBstore + lBstart + " + to_string(n*lBld+k)) << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
@@ -380,7 +337,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stream << "//Inner loop" << std::endl;
|
|
|
|
|
stream << "for(unsigned int k = 0; k < " << p_.kL << " && (block_k + k < chunk_size); k+=" << p_.kS << "){" << std::endl;
|
|
|
|
|
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
|
|
|
|
|
stream.inc_tab();
|
|
|
|
|
|
|
|
|
|
stream << "//Fetch A to registers" << std::endl;
|
|
|
|
@@ -480,17 +437,21 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
|
|
|
|
|
{
|
|
|
|
|
if (A_trans_=='N')
|
|
|
|
|
stream << "A += " << p_.kL << "*Ald;" << std::endl;
|
|
|
|
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
|
|
|
|
stream << "Ai[" << i << "] += " << p_.kL << "*Ald;" << std::endl;
|
|
|
|
|
else
|
|
|
|
|
stream << "A += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
|
|
|
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
|
|
|
|
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
|
|
|
|
|
{
|
|
|
|
|
if (B_trans_=='T')
|
|
|
|
|
stream << "B += " << p_.kL << "*Bld;" << std::endl;
|
|
|
|
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
|
|
|
|
stream << "Bi[" << i << "] += " << p_.kL << "*Bld;" << std::endl;
|
|
|
|
|
else
|
|
|
|
|
stream << "B += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
|
|
|
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
|
|
|
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
stream.dec_tab();
|
|
|
|
@@ -552,7 +513,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// std::cout << stream.str() << std::endl;
|
|
|
|
|
std::cout << stream.str() << std::endl;
|
|
|
|
|
return stream.str();
|
|
|
|
|
|
|
|
|
|
#undef HANDLE_BOUNDS
|
|
|
|
@@ -738,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), true, 'N', 'N')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -747,7 +708,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), true, 'T', 'N')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -756,7 +717,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), true, 'N', 'T')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -765,7 +726,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
|
|
|
|
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), true, 'T', 'T')
|
|
|
|
|
{ }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|