GEMM: reverted AMD optimizations
This commit is contained in:
@@ -111,7 +111,7 @@ void bench(sc::numeric_type dtype, std::string operation)
|
|||||||
std::cout << "#" << operation << " (" << metric[operation] << ")" << std::endl;
|
std::cout << "#" << operation << " (" << metric[operation] << ")" << std::endl;
|
||||||
std::cout << "\"N\"";
|
std::cout << "\"N\"";
|
||||||
std::cout << " \"ISAAC\"";
|
std::cout << " \"ISAAC\"";
|
||||||
std::cout << " \"ISAAC (Best impl.)\"";
|
// std::cout << " \"ISAAC (Best impl.)\"";
|
||||||
#ifdef BENCH_CLBLAS
|
#ifdef BENCH_CLBLAS
|
||||||
std::cout << " \"clBLAS\"";
|
std::cout << " \"clBLAS\"";
|
||||||
#endif
|
#endif
|
||||||
@@ -314,7 +314,7 @@ void bench(sc::numeric_type dtype, std::string operation)
|
|||||||
int_t lda = A.stride()[1], ldb = B.stride()[1], ldc = C.stride()[1];
|
int_t lda = A.stride()[1], ldb = B.stride()[1], ldc = C.stride()[1];
|
||||||
#endif
|
#endif
|
||||||
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t);
|
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t);
|
||||||
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t);
|
// BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t);
|
||||||
/* clblas */
|
/* clblas */
|
||||||
#ifdef BENCH_CLBLAS
|
#ifdef BENCH_CLBLAS
|
||||||
if(C.context().backend()==sc::driver::OPENCL)
|
if(C.context().backend()==sc::driver::OPENCL)
|
||||||
|
@@ -282,17 +282,29 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
|
|
||||||
stream << std::endl;
|
stream << std::endl;
|
||||||
stream << "//Outer loop" << std::endl;
|
stream << "//Outer loop" << std::endl;
|
||||||
stream << "while(K >=" << p_.kL << ")" << std::endl;
|
stream << "while(K > 0)" << std::endl;
|
||||||
stream << "{" << std::endl;
|
stream << "{" << std::endl;
|
||||||
stream.inc_tab();
|
stream.inc_tab();
|
||||||
|
|
||||||
|
auto do_fetch = [&](bool last_iteration)
|
||||||
auto fetch_to_lds = [&](bool last_iteration)
|
|
||||||
{
|
{
|
||||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
if(last_iteration)
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
{
|
||||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
if(A_trans_=='N' || B_trans_=='T')
|
||||||
|
{
|
||||||
|
stream << "int Ky = K - idT.y;" << std::endl;
|
||||||
|
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||||
|
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(A_trans_=='T' || B_trans_=='N')
|
||||||
|
{
|
||||||
|
stream << "int Kx = K - idT.x;" << std::endl;
|
||||||
|
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
||||||
|
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||||
|
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
stream << "//Fetch A to local memory" << std::endl;
|
stream << "//Fetch A to local memory" << std::endl;
|
||||||
if (A_trans_=='N')
|
if (A_trans_=='N')
|
||||||
{
|
{
|
||||||
@@ -354,6 +366,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||||
|
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||||
|
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||||
|
|
||||||
|
stream << "if(K >= " << p_.kL << ")" << std::endl;
|
||||||
|
stream << "{" << std::endl;
|
||||||
|
do_fetch(false);
|
||||||
|
stream << "}" << std::endl;
|
||||||
|
stream << "else{" << std::endl;
|
||||||
|
do_fetch(true);
|
||||||
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
if(A_trans_=='N')
|
if(A_trans_=='N')
|
||||||
stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
||||||
@@ -427,7 +453,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
|
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
|
||||||
else
|
else
|
||||||
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width);
|
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width);
|
||||||
stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
stream << res_str << "=" << "mad(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
@@ -454,31 +480,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
fetch_to_lds(false);
|
|
||||||
|
|
||||||
|
|
||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
|
|
||||||
if(A_trans_=='N' || B_trans_=='T')
|
|
||||||
{
|
|
||||||
stream << "int Ky = K - idT.y;" << std::endl;
|
|
||||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
|
||||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(A_trans_=='T' || B_trans_=='N')
|
// fetch_to_lds(true);
|
||||||
{
|
|
||||||
stream << "int Kx = K - idT.x;" << std::endl;
|
|
||||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
|
||||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
|
||||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
|
||||||
}
|
|
||||||
fetch_to_lds(true);
|
|
||||||
|
|
||||||
stream << "//Write back C" << std::endl;
|
stream << "//Write back C" << std::endl;
|
||||||
stream << "M += ids.x;" << std::endl;
|
stream << "M += ids.x;" << std::endl;
|
||||||
@@ -516,10 +523,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
{
|
{
|
||||||
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
|
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
|
||||||
stream << "if(" << Ci << "< M) ";
|
stream << "if(" << Ci << "< M) ";
|
||||||
if(has_depth)
|
|
||||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
||||||
else
|
|
||||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + (beta?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
|
|
||||||
}
|
}
|
||||||
if((n+1)%p_.simd_width==0){
|
if((n+1)%p_.simd_width==0){
|
||||||
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
|
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
|
||||||
@@ -564,10 +568,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
return stream.str();
|
return stream.str();
|
||||||
|
|
||||||
#undef VLOAD
|
|
||||||
#undef VST0RE
|
#undef VST0RE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user