GEMM: Improved performance for cases other than NT
This commit is contained in:
@@ -122,9 +122,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
std::string vdtype = append_width(sdtype, p_.simd_width);
|
||||
std::string _size_t = size_type(device);
|
||||
|
||||
size_t lAld = p_.mL;
|
||||
size_t lBld = p_.nL;
|
||||
|
||||
//////////////////
|
||||
/// DECLARATIONS
|
||||
/// //////////////
|
||||
@@ -161,8 +158,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.simd_width << "];" << std::endl;
|
||||
|
||||
///Result Values
|
||||
stream << Local(backend) << " " << sdtype << " lA[" << p_.kL*lAld << "];" << std::endl;
|
||||
stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*lBld << "];" << std::endl;
|
||||
size_t lAld = (A_trans_=='N')?p_.mL:p_.kL;
|
||||
stream << Local(backend) << " " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
|
||||
size_t lBld = (B_trans_=='T')?p_.nL:p_.kL;
|
||||
stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
|
||||
stream << std::endl;
|
||||
|
||||
stream << "size_t gidx = " << GroupIdx0(backend) << ";" << std::endl;
|
||||
@@ -194,12 +193,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
if (A_trans_=='N')
|
||||
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
|
||||
else
|
||||
stream << "A += (idxT)*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl;
|
||||
stream << "A += idxT*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl;
|
||||
|
||||
if(B_trans_=='T')
|
||||
stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
|
||||
else
|
||||
stream << "B += (idxT)*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl;
|
||||
stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl;
|
||||
|
||||
|
||||
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||
@@ -222,19 +221,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
else
|
||||
stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*Bld;" << std::endl;
|
||||
|
||||
|
||||
|
||||
if (A_trans_=='N')
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idyT*" << lAld << " + idxT*" << p_.simd_width << ";" << std::endl;
|
||||
else
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idxT*" << lAld*p_.simd_width << " + idyT;" << std::endl;
|
||||
|
||||
|
||||
if (B_trans_=='T')
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lBld << " + idxT*" << p_.simd_width << ";" << std::endl;
|
||||
else if (B_trans_=='N')
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idxT*" << lBld*p_.simd_width << " + idyT;" << std::endl;
|
||||
|
||||
|
||||
stream << "//Outer loop" << std::endl;
|
||||
stream << "for(size_t block_k=0; block_k < chunk_size ; block_k+=" << p_.kL << "){" << std::endl;
|
||||
@@ -255,23 +243,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*lAld+m)) << ";" << std::endl;
|
||||
}
|
||||
else
|
||||
for(int_t k = 0; k < p_.mL; k += p_.local_fetch_1)
|
||||
for(int_t m = 0; m < p_.kL; m += p_.local_fetch_0*p_.simd_width)
|
||||
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(int_t m = 0; m < p_.mL; m += p_.local_fetch_1)
|
||||
{
|
||||
std::string mm = to_string(k/p_.local_fetch_1);
|
||||
std::string kk = to_string(m/p_.simd_width);
|
||||
std::string mm = to_string(m/p_.local_fetch_1);
|
||||
std::string kk = to_string(k);
|
||||
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
|
||||
if(check_bounds_)
|
||||
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||
if(p_.simd_width==1)
|
||||
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*lAld+k)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
stream << vdtype << " tmpA" << k << "_" << m << " = " << to_load << ";" << std::endl;
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << "lAstore[" << k + (m + s)*lAld << "]= tmpA" << k << "_" << m << ".s" << s << ";" << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
stream << "//Fetch B to local memory" << std::endl;
|
||||
@@ -287,27 +267,27 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lBld+n)) << ";" << std::endl;
|
||||
}
|
||||
else
|
||||
for(int_t k = 0; k < p_.nL; k += p_.local_fetch_1)
|
||||
for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width)
|
||||
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(int_t n = 0; n < p_.nL; n += p_.local_fetch_1)
|
||||
{
|
||||
std::string nn = to_string(k/p_.local_fetch_1);
|
||||
std::string kk = to_string(n);
|
||||
std::string nn = to_string(n/p_.local_fetch_1);
|
||||
std::string kk = to_string(k);
|
||||
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
|
||||
if(check_bounds_)
|
||||
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
|
||||
if(p_.simd_width==1)
|
||||
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lBld+k)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
stream << vdtype << " tmpB" << k << "_" << n << " = " << to_load << ";" << std::endl;
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << "lBstore[" << k + (n + s)*lBld << "]= tmpB" << k << "_" << n << ".s" << s << ";" << std::endl;
|
||||
}
|
||||
}
|
||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||
|
||||
if(A_trans_=='N')
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << p_.simd_width << ";" << std::endl;
|
||||
else
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << lAld*p_.simd_width << ";" << std::endl;
|
||||
|
||||
if(B_trans_=='T')
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << p_.simd_width << ";" << std::endl;
|
||||
else
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << lBld*p_.simd_width << ";" << std::endl;
|
||||
|
||||
|
||||
stream << "//Inner loop" << std::endl;
|
||||
@@ -321,7 +301,17 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.simd_width << "; mm++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(A_trans_=='N')
|
||||
stream << "rA[kk][mm] = " << VLOAD("0", "readA + k*" + to_string(lAld) + " + mm*" + to_string(p_.local_size_0*p_.simd_width) + "+ kk*" + to_string(lAld)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.simd_width==1)
|
||||
stream << "rA[kk][mm] = readA[k + mm*" << p_.local_size_0*lAld << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << access_vector_type("rA[kk][mm]", s) << " = readA[k + (mm*" << p_.simd_width*p_.local_size_0 << " + " << s << ")*" << lAld << "+ kk];" << std::endl;
|
||||
}
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
@@ -332,7 +322,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.simd_width << "; nn++)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
if(B_trans_=='T')
|
||||
stream << "rB[kk][nn] = " << VLOAD("0", "readB + k*" + to_string(lBld) + " + nn*" + to_string(p_.local_size_1*p_.simd_width) + "+ kk*" + to_string(lBld)) << ";" << std::endl;
|
||||
else
|
||||
{
|
||||
if(p_.simd_width==1)
|
||||
stream << "rB[kk][nn] = readB[k" << " + nn*" << p_.local_size_1*lBld << "+ kk" << "];" << std::endl;
|
||||
else
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << access_vector_type("rB[kk][nn]", s) << " = readB[k" << " + (nn*" << p_.simd_width*p_.local_size_1 << " + " << s << ")*" << lBld << "+ kk];" << std::endl;
|
||||
}
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
@@ -417,7 +416,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << sdtype << " acc = 0;" << std::endl;
|
||||
stream << "for(unsigned int k = 0 ; k < D ; k++)" << std::endl;
|
||||
stream << "for(unsigned int k = 0 ; k < D ; k++){" << std::endl;
|
||||
stream.inc_tab();
|
||||
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
|
||||
stream.dec_tab();
|
||||
|
@@ -115,7 +115,7 @@ def main():
|
||||
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
||||
|
||||
#Source files
|
||||
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||
boostsrc = 'external/boost/libs/'
|
||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||
|
Reference in New Issue
Block a user