GEMM: Better handling of AT=1 and BT=0
This commit is contained in:
@@ -52,8 +52,12 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
||||
unsigned int gemm::lmem_usage(expression_tree const & expression) const
|
||||
{
|
||||
unsigned int N = 0;
|
||||
N += p_.kL * p_.mL;
|
||||
N += p_.nL * p_.kL;
|
||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
||||
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
||||
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
N += llda*lnda;
|
||||
N += lldb*lndb;
|
||||
return N*size_of(expression.dtype());
|
||||
}
|
||||
|
||||
@@ -180,10 +184,12 @@ gemm_parameters::gemm_parameters(unsigned int vwidth
|
||||
stream << std::endl;
|
||||
|
||||
stream << "//pointers" << std::endl;
|
||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL;
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL;
|
||||
stream << "$LOCAL " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
|
||||
stream << "$LOCAL " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
|
||||
size_t llda = (A_trans_=='N')?p_.mL:p_.kL+1;
|
||||
size_t lnda = (A_trans_=='N')?p_.kL:p_.mL;
|
||||
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL+1;
|
||||
size_t lndb = (B_trans_=='T')?p_.kL:p_.nL;
|
||||
stream << "$LOCAL " << sdtype << " lA[" << llda*lnda << "];" << std::endl;
|
||||
stream << "$LOCAL " << sdtype << " lB[" << lldb*lndb << "];" << std::endl;
|
||||
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
||||
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -99,22 +99,22 @@ class Tuner:
|
||||
if level=='simple':
|
||||
sizes = [(2560,2560,2560)]
|
||||
elif level=='intermediate':
|
||||
sizes = [#Square
|
||||
(896,896,896),
|
||||
(1536,1536,1536),
|
||||
(2176, 2176,2176),
|
||||
#Rank-32 updates
|
||||
(896,896,32),
|
||||
(1536,1536,32),
|
||||
(2176,2176,32),
|
||||
#Covariance
|
||||
(32,32,16000),
|
||||
(64,64,64000),
|
||||
(256,256,32000)]
|
||||
sizes = []
|
||||
#Square
|
||||
for N in [896, 1760, 2048, 2560]:
|
||||
sizes += [(N, N, N)]
|
||||
#LaPack
|
||||
for N in [896, 1760, 2048, 2560]:
|
||||
for K in [16, 32, 64, 128]:
|
||||
sizes += [(N, N, K)]
|
||||
#Covariance
|
||||
for N in [16, 32, 64, 128]:
|
||||
for K in [16000,32000,64000,128000]:
|
||||
sizes += [(N, N, K)]
|
||||
#DeepSpeech
|
||||
for MK in [1760, 2048, 2560]:
|
||||
for N in [16, 32, 64, 128, MK]:
|
||||
sizes += [(MK, N, MK)]
|
||||
for M in [1760, 2048, 2560]:
|
||||
for N in [16, 32, 64, 128, M]:
|
||||
sizes += [(M, N, M)]
|
||||
elif level=='full':
|
||||
sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 17))
|
||||
|
||||
|
Reference in New Issue
Block a user