Cleaning GEMM test

This commit is contained in:
Philippe Tillet
2015-07-18 13:09:38 -04:00
parent 54ad83f4a6
commit ab82a9c048
2 changed files with 33 additions and 32 deletions

View File

@@ -168,10 +168,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "B += offb;" << std::endl;
stream << "C += offc;" << std::endl;
stream << "size_t gidx = " << GroupIdx0(backend) << ";" << std::endl;
stream << "size_t gidy = " << GroupIdx1(backend) << ";" << std::endl;
stream << "size_t idx = " << LocalIdx0(backend) << ";" << std::endl;
stream << "size_t idy = " << LocalIdx1(backend) << ";" << std::endl;
stream << "int4 ids = (int4)(" << GroupIdx0(backend) << "," << GroupIdx1(backend) << "," << LocalIdx0(backend) << "," << LocalIdx1(backend) << ");" << std::endl;
stream << "size_t idt = " << p_.local_size_0 << "*ids.w + ids.z;" << std::endl;
stream << "int2 idT;" << std::endl;
stream << "idT.y = idt/" << p_.local_fetch_0 << ";" << std::endl;
stream << "idT.x = idt - " << p_.local_fetch_0 << "*idT.y;" << std::endl;
if(has_depth)
{
@@ -181,39 +183,38 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "K = min(K - div*gidz, div);" << std::endl;
}
stream << std::endl;
stream << "size_t idt = " << p_.local_size_0 << "*idy + idx;" << std::endl;
stream << "size_t idxT = idt % " << p_.local_fetch_0 << ";" << std::endl;
stream << "size_t idyT = idt / " << p_.local_fetch_0 << ";" << std::endl;
stream << std::endl;
stream << "ids.x *= " << p_.mL << ";" << std::endl;
stream << "ids.y *= " << p_.nL << ";" << std::endl;
stream << "idT.x *= " << p_.simd_width << ";" << std::endl;
if (A_trans_=='N')
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
stream << "A += (idT.x + ids.x)" << ASTRIDE1 << " + idT.y*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
else
stream << "A += idxT*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*lda" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "A += idT.x" << ASTRIDE1 << " + idT.y*lda + ids.x*lda" << (has_depth?"+ offz":"") << ";" << std::endl;
if(B_trans_=='T')
stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl;
stream << "B += (idT.x + ids.y)" << BSTRIDE1 << " + idT.y*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl;
else
stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "B += idT.x" << BSTRIDE1 << " + idT.y*ldb + ids.y*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npA << " ; ++i) Ai[i] = A;" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl;
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "if(gidx*" << p_.mL << " + idxT*" << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl;
stream << "if(ids.x + idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl;
else
stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
stream << "if(ids.x + idT.y + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
stream << "if(ids.y + idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
else
stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << "if(ids.y + idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idyT*" << llda << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lldb << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "for(long block_k=K; block_k > 0 ; block_k-=" << p_.kL << "){" << std::endl;
@@ -229,7 +230,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
to_load = "(idT.y + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
@@ -241,7 +242,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
to_load = "(idT.x + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -255,7 +256,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
to_load = "(idT.y + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
@@ -267,21 +268,21 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
to_load = "(idT.x + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N')
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << llda*p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl;
if(B_trans_=='T')
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << p_.simd_width << ";" << std::endl;
else
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << lldb*p_.simd_width << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl;
stream << "//Inner loop" << std::endl;
@@ -371,8 +372,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "//Write back C" << std::endl;
stream << "size_t offx = (gidx*" << p_.mL << " + idx*" << p_.simd_width << ")" << ";" << std::endl;
stream << "size_t offy = (gidy*" << p_.nL << " + idy*" << p_.simd_width << ");" << std::endl;
stream << "size_t offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
stream << "size_t offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;
stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl;
stream << "N -= offy;" << std::endl;
stream << "M -= offx;" << std::endl;