More cleaning

This commit is contained in:
Philippe Tillet
2016-10-02 20:21:38 -04:00
parent 77178d7017
commit a26582d34b
20 changed files with 214 additions and 220 deletions

View File

@@ -37,11 +37,11 @@ namespace isaac
namespace templates
{
gemm_parameters::gemm_parameters(unsigned int vwidth
, unsigned int ls0, unsigned int KL, unsigned int ls1, unsigned int D
, unsigned int ms, unsigned int ks, unsigned int ns
, fetch_type Afetch, fetch_type Bfetch
, unsigned int lf0, unsigned int lf1): base::parameters_type(vwidth, ls0, ls1, 1),
gemm_parameters::gemm_parameters(uint32_t vwidth
,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D
,uint32_t ms, uint32_t ks, uint32_t ns
,fetch_type Afetch, fetch_type Bfetch
,uint32_t lf0, uint32_t lf1): base::parameters_type(vwidth, ls0, ls1, 1),
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
lf0(lf0), lf1(lf1),
mL(ms*ls0), nL(ns*ls1)
@@ -49,9 +49,9 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
unsigned int gemm::lmem_usage(expression_tree const & expression) const
uint32_t gemm::lmem_usage(expression_tree const & expression) const
{
unsigned int N = 0;
uint32_t N = 0;
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
@@ -61,13 +61,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
return N*size_of(expression.dtype());
}
unsigned int gemm::registers_usage(expression_tree const & expression) const
uint32_t gemm::registers_usage(expression_tree const & expression) const
{
unsigned int N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
uint32_t N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
return N*size_of(expression.dtype());
}
unsigned int gemm::temporary_workspace(expression_tree const & expressions) const
uint32_t gemm::temporary_workspace(expression_tree const & expressions) const
{
std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1];
@@ -97,8 +97,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
if (p_.Afetch==FETCH_FROM_LOCAL)
{
unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
uint32_t bound1 = (A_trans_=='N')?p_.kL:p_.mL;
uint32_t bound0 = (A_trans_=='N')?p_.mL:p_.kL;
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
@@ -109,8 +109,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
if (p_.Bfetch==FETCH_FROM_LOCAL)
{
unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
uint32_t bound1 = (B_trans_=='T')?p_.kL:p_.nL;
uint32_t bound0 = (B_trans_=='T')?p_.nL:p_.kL;
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
@@ -190,8 +190,8 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
uint32_t npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
uint32_t npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << std::endl;
@@ -290,13 +290,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "}" << std::endl;
stream << std::endl;
for(unsigned int i = 0 ; i < npA ; i++ )
for(uint32_t i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
else
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
for(uint32_t i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
else
@@ -318,13 +318,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
for(unsigned int m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
{
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
std::string kk = to_string(k);
if(last_iteration)
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
@@ -332,13 +332,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
else
{
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(unsigned int m = 0; m < p_.mL; m += p_.lf1)
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(uint32_t m = 0; m < p_.mL; m += p_.lf1)
{
std::string mm = to_string(m/p_.lf1);
std::string kk = to_string(k);
if(last_iteration)
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
else
@@ -349,13 +349,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "//Fetch B to local memory" << std::endl;
if (B_trans_=='T')
{
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
for(unsigned int n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
{
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
std::string kk = to_string(k);
if(last_iteration)
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
@@ -363,13 +363,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
else
{
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(unsigned int n = 0; n < p_.nL; n += p_.lf1)
for(uint32_t k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
for(uint32_t n = 0; n < p_.nL; n += p_.lf1)
{
std::string nn = to_string(n/p_.lf1);
std::string kk = to_string(k);
if(last_iteration)
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
else
@@ -391,14 +391,14 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
size_t ks = last_iteration?1:p_.kS;
stream << "//Inner loop" << std::endl;
stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
stream << "for(uint32_t k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
stream.inc_tab();
stream << "//Fetch A to registers" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
stream << "for(uint32_t mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(A_trans_=='N')
@@ -408,7 +408,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
if(p_.vwidth==1)
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
else
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
}
@@ -417,9 +417,9 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "//Fetch B to registers" << std::endl;
stream << "#pragma unroll " << ks << std::endl;
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "for(uint32_t kk = 0; kk < " << ks << "; kk++)" << std::endl;
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
stream << "for(uint32_t nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(B_trans_=='T')
@@ -429,7 +429,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
if(p_.vwidth==1)
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
else
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
}
stream.dec_tab();
@@ -437,10 +437,10 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "//FMA computations" << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
stream << "for(uint32_t kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
stream.inc_tab();
for(unsigned int nn=0; nn < p_.nS; ++nn)
for(unsigned int mm=0; mm < p_.mS; ++mm){
for(uint32_t nn=0; nn < p_.nS; ++nn)
for(uint32_t mm=0; mm < p_.mS; ++mm){
string res_str, lhs_str, rhs_str;
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
if (p_.vwidth==1)
@@ -461,18 +461,18 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
//Increment A pointers to global memory
if (A_trans_=='N')
for(unsigned int i = 0 ; i < npA ; ++i)
for(uint32_t i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
else
for(unsigned int i = 0 ; i < npA ; ++i)
for(uint32_t i = 0 ; i < npA ; ++i)
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
//Increment B pointers to global memory
if (B_trans_=='T')
for(unsigned int i = 0 ; i < npB ; ++i)
for(uint32_t i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
else
for(unsigned int i = 0 ; i < npB ; ++i)
for(uint32_t i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
};
fetch_to_lds(false);
@@ -483,15 +483,15 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
if(A_trans_=='N' || B_trans_=='T')
{
stream << "int Ky = K - idT.y;" << std::endl;
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
for(uint32_t k = 0; k < p_.kL; k += p_.lf1)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N')
{
stream << "int Kx = K - idT.x;" << std::endl;
for(unsigned int k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
for(uint32_t k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
for(uint32_t s = 0 ; s < p_.vwidth ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
}
fetch_to_lds(true);
@@ -522,13 +522,13 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << "N -= ids.y;" << std::endl;
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
for(unsigned int n=0; n < p_.nS; ++n)
for(uint32_t n=0; n < p_.nS; ++n)
{
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
stream << "if(" << Cj << " >= N) return;" << std::endl;
for(unsigned int m=0; m < p_.mS; ++m)
for(uint32_t m=0; m < p_.mS; ++m)
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
for(unsigned int m=0; m < p_.mS; ++m)
for(uint32_t m=0; m < p_.mS; ++m)
{
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
stream << "if(" << Ci << "< M) ";
@@ -560,14 +560,14 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream.inc_tab();
stream << "C += Cstart;" << std::endl;
stream << "for(unsigned int i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
stream << "for(uint32_t i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << "for(unsigned int j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
stream << "for(uint32_t j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
stream << sdtype << " acc = 0;" << std::endl;
stream << "for(unsigned int k = 0 ; k < D ; k++)" << std::endl;
stream << "for(uint32_t k = 0 ; k < D ; k++)" << std::endl;
stream.inc_tab();
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
stream.dec_tab();
@@ -609,7 +609,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
driver::NDRange local(p_.ls0, p_.ls1, 1);
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);
unsigned int current_arg = 0;
uint32_t current_arg = 0;
driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context()));
gemm.setSizeArg(current_arg++, M);
@@ -656,7 +656,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
if(p_.depth > 1)
{
unsigned int current_arg = 0;
uint32_t current_arg = 0;
driver::Kernel reduce(program, reduce_name.c_str());
driver::NDRange local(p_.ls0, p_.ls1);
driver::NDRange global(align(M, p_.ls0), align(N, p_.ls1));
@@ -721,7 +721,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
//
gemm_nn::gemm_nn(unsigned int simd
gemm_nn::gemm_nn(uint32_t simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
@@ -731,7 +731,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
}
//
gemm_tn::gemm_tn(unsigned int simd
gemm_tn::gemm_tn(uint32_t simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
@@ -740,7 +740,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
{ }
//
gemm_nt::gemm_nt(unsigned int simd
gemm_nt::gemm_nt(uint32_t simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch
@@ -749,7 +749,7 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
{ }
//
gemm_tt::gemm_tt(unsigned int simd
gemm_tt::gemm_tt(uint32_t simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetch_type Afetch , fetch_type Bfetch