GEMM: Better handling of AT=1 and BT=0

This commit is contained in:
Philippe Tillet
2016-10-02 17:37:49 -04:00
parent e1baf85707
commit 77178d7017
4 changed files with 2250 additions and 9867 deletions

View File

@@ -52,8 +52,12 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
unsigned int gemm::lmem_usage(expression_tree const & expression) const unsigned int gemm::lmem_usage(expression_tree const & expression) const
{ {
unsigned int N = 0; unsigned int N = 0;
N += p_.kL * p_.mL; size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
N += p_.nL * p_.kL; size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
N += llda*lnda;
N += lldb*lndb;
return N*size_of(expression.dtype()); return N*size_of(expression.dtype());
} }
@@ -180,10 +184,12 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
stream << std::endl; stream << std::endl;
stream << "//pointers" << std::endl; stream << "//pointers" << std::endl;
size_t llda = (A_trans_=='N')?p_.mL:p_.kL; size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL; size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
stream << "$LOCAL " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl; size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
stream << "$LOCAL " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl; size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1); unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1); unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl; stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -99,22 +99,22 @@ class Tuner:
if level=='simple': if level=='simple':
sizes = [(2560,2560,2560)] sizes = [(2560,2560,2560)]
elif level=='intermediate': elif level=='intermediate':
sizes = [#Square sizes = []
(896,896,896), #Square
(1536,1536,1536), for N in [896, 1760, 2048, 2560]:
(2176, 2176,2176), sizes += [(N, N, N)]
#Rank-32 updates #LaPack
(896,896,32), for N in [896, 1760, 2048, 2560]:
(1536,1536,32), for K in [16, 32, 64, 128]:
(2176,2176,32), sizes += [(N, N, K)]
#Covariance #Covariance
(32,32,16000), for N in [16, 32, 64, 128]:
(64,64,64000), for K in [16000,32000,64000,128000]:
(256,256,32000)] sizes += [(N, N, K)]
#DeepSpeech #DeepSpeech
for MK in [1760, 2048, 2560]: for M in [1760, 2048, 2560]:
for N in [16, 32, 64, 128, MK]: for N in [16, 32, 64, 128, M]:
sizes += [(MK, N, MK)] sizes += [(M, N, M)]
elif level=='full': elif level=='full':
sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 17)) sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 17))