GEMM: reverted AMD optimizations

This commit is contained in:
Philippe Tillet
2015-11-29 16:13:14 -05:00
parent b3c5251f91
commit f975ea7621
2 changed files with 664 additions and 663 deletions

View File

@@ -111,7 +111,7 @@ void bench(sc::numeric_type dtype, std::string operation)
std::cout << "#" << operation << " (" << metric[operation] << ")" << std::endl; std::cout << "#" << operation << " (" << metric[operation] << ")" << std::endl;
std::cout << "\"N\""; std::cout << "\"N\"";
std::cout << " \"ISAAC\""; std::cout << " \"ISAAC\"";
std::cout << " \"ISAAC (Best impl.)\""; // std::cout << " \"ISAAC (Best impl.)\"";
#ifdef BENCH_CLBLAS #ifdef BENCH_CLBLAS
std::cout << " \"clBLAS\""; std::cout << " \"clBLAS\"";
#endif #endif
@@ -314,7 +314,7 @@ void bench(sc::numeric_type dtype, std::string operation)
int_t lda = A.stride()[1], ldb = B.stride()[1], ldc = C.stride()[1]; int_t lda = A.stride()[1], ldb = B.stride()[1], ldc = C.stride()[1];
#endif #endif
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t); BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t);
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t); // BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t);
/* clblas */ /* clblas */
#ifdef BENCH_CLBLAS #ifdef BENCH_CLBLAS
if(C.context().backend()==sc::driver::OPENCL) if(C.context().backend()==sc::driver::OPENCL)

View File

@@ -282,17 +282,29 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << std::endl; stream << std::endl;
stream << "//Outer loop" << std::endl; stream << "//Outer loop" << std::endl;
stream << "while(K >=" << p_.kL << ")" << std::endl; stream << "while(K > 0)" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
auto do_fetch = [&](bool last_iteration)
auto fetch_to_lds = [&](bool last_iteration)
{ {
stream << LocalBarrier(backend) << ";" << std::endl; if(last_iteration)
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl; {
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl; if(A_trans_=='N' || B_trans_=='T')
{
stream << "int Ky = K - idT.y;" << std::endl;
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N')
{
stream << "int Kx = K - idT.x;" << std::endl;
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
}
}
stream << "//Fetch A to local memory" << std::endl; stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N') if (A_trans_=='N')
{ {
@@ -354,6 +366,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl; stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
} }
} }
};
stream << LocalBarrier(backend) << ";" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
stream << "if(K >= " << p_.kL << ")" << std::endl;
stream << "{" << std::endl;
do_fetch(false);
stream << "}" << std::endl;
stream << "else{" << std::endl;
do_fetch(true);
stream << "}" << std::endl;
if(A_trans_=='N') if(A_trans_=='N')
stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl; stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
@@ -427,7 +453,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]"; rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
else else
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width); rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width);
stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl; stream << res_str << "=" << "mad(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
} }
stream.dec_tab(); stream.dec_tab();
@@ -454,31 +480,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl; stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
};
fetch_to_lds(false);
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
{
stream << "int Ky = K - idT.y;" << std::endl;
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N') // fetch_to_lds(true);
{
stream << "int Kx = K - idT.x;" << std::endl;
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
}
fetch_to_lds(true);
stream << "//Write back C" << std::endl; stream << "//Write back C" << std::endl;
stream << "M += ids.x;" << std::endl; stream << "M += ids.x;" << std::endl;
@@ -516,10 +523,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{ {
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width); string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
stream << "if(" << Ci << "< M) "; stream << "if(" << Ci << "< M) ";
if(has_depth)
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl; stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
else
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + (beta?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
} }
if((n+1)%p_.simd_width==0){ if((n+1)%p_.simd_width==0){
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl; stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
@@ -564,10 +568,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
} }
return stream.str(); return stream.str();
#undef VLOAD
#undef VST0RE #undef VST0RE
} }