GEMM: reverted AMD optimizations
This commit is contained in:
@@ -111,7 +111,7 @@ void bench(sc::numeric_type dtype, std::string operation)
|
||||
std::cout << "#" << operation << " (" << metric[operation] << ")" << std::endl;
|
||||
std::cout << "\"N\"";
|
||||
std::cout << " \"ISAAC\"";
|
||||
std::cout << " \"ISAAC (Best impl.)\"";
|
||||
// std::cout << " \"ISAAC (Best impl.)\"";
|
||||
#ifdef BENCH_CLBLAS
|
||||
std::cout << " \"clBLAS\"";
|
||||
#endif
|
||||
@@ -314,7 +314,7 @@ void bench(sc::numeric_type dtype, std::string operation)
|
||||
int_t lda = A.stride()[1], ldb = B.stride()[1], ldc = C.stride()[1];
|
||||
#endif
|
||||
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t);
|
||||
BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t);
|
||||
// BENCHMARK_ISAAC(C = sc::execution_handler(AT?(BT?dot(A.T,B.T):dot(A.T,B)):(BT?dot(A,B.T):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(true)), (double)2*M*N*K/t);
|
||||
/* clblas */
|
||||
#ifdef BENCH_CLBLAS
|
||||
if(C.context().backend()==sc::driver::OPENCL)
|
||||
|
@@ -27,36 +27,36 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
}
|
||||
|
||||
|
||||
unsigned int gemm::lmem_usage(math_expression const & expression) const
|
||||
{
|
||||
unsigned int gemm::lmem_usage(math_expression const & expression) const
|
||||
{
|
||||
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
|
||||
unsigned int N = 0;
|
||||
N += p_.kL * p_.mL;
|
||||
N += p_.nL * p_.kL;
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned int gemm::registers_usage(math_expression const & expression) const
|
||||
{
|
||||
unsigned int gemm::registers_usage(math_expression const & expression) const
|
||||
{
|
||||
numeric_type numeric_t = lhs_most(expression.tree(), expression.root()).lhs.dtype;
|
||||
|
||||
unsigned int N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned int gemm::temporary_workspace(math_expression const & expressions) const
|
||||
{
|
||||
unsigned int gemm::temporary_workspace(math_expression const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
if(p_.depth > 1)
|
||||
return M*N*p_.depth;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
{
|
||||
// if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
|
||||
// return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
int gemm::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
{
|
||||
// if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
|
||||
// return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
|
||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
@@ -101,10 +101,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
}
|
||||
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
}
|
||||
|
||||
std::string gemm::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const &) const
|
||||
{
|
||||
std::string gemm::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const &) const
|
||||
{
|
||||
using std::string;
|
||||
using tools::to_string;
|
||||
|
||||
@@ -282,17 +282,29 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
|
||||
stream << std::endl;
|
||||
stream << "//Outer loop" << std::endl;
|
||||
stream << "while(K >=" << p_.kL << ")" << std::endl;
|
||||
stream << "while(K > 0)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
|
||||
auto fetch_to_lds = [&](bool last_iteration)
|
||||
auto do_fetch = [&](bool last_iteration)
|
||||
{
|
||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||
if(last_iteration)
|
||||
{
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
{
|
||||
stream << "int Ky = K - idT.y;" << std::endl;
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||
}
|
||||
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
{
|
||||
stream << "int Kx = K - idT.x;" << std::endl;
|
||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||
}
|
||||
}
|
||||
stream << "//Fetch A to local memory" << std::endl;
|
||||
if (A_trans_=='N')
|
||||
{
|
||||
@@ -354,6 +366,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
stream << LocalBarrier(backend) << ";" << std::endl;
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
||||
stream << LocalPtr(backend) << " " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
||||
|
||||
stream << "if(K >= " << p_.kL << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
do_fetch(false);
|
||||
stream << "}" << std::endl;
|
||||
stream << "else{" << std::endl;
|
||||
do_fetch(true);
|
||||
stream << "}" << std::endl;
|
||||
|
||||
if(A_trans_=='N')
|
||||
stream << "ldsA = lA + ids.z*" << p_.simd_width << ";" << std::endl;
|
||||
@@ -427,7 +453,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
rhs_str = "rB[" + to_string(kk) + "]["+to_string(nn)+"]";
|
||||
else
|
||||
rhs_str = access_vector_type("rB[" + to_string(kk) + "]["+to_string(nn/p_.simd_width)+"]", nn%p_.simd_width);
|
||||
stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||
stream << res_str << "=" << "mad(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
||||
}
|
||||
|
||||
stream.dec_tab();
|
||||
@@ -454,31 +480,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
fetch_to_lds(false);
|
||||
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
|
||||
if(A_trans_=='N' || B_trans_=='T')
|
||||
{
|
||||
stream << "int Ky = K - idT.y;" << std::endl;
|
||||
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
|
||||
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
||||
}
|
||||
|
||||
if(A_trans_=='T' || B_trans_=='N')
|
||||
{
|
||||
stream << "int Kx = K - idT.x;" << std::endl;
|
||||
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
|
||||
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
|
||||
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
||||
}
|
||||
fetch_to_lds(true);
|
||||
// fetch_to_lds(true);
|
||||
|
||||
stream << "//Write back C" << std::endl;
|
||||
stream << "M += ids.x;" << std::endl;
|
||||
@@ -516,10 +523,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
{
|
||||
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
|
||||
stream << "if(" << Ci << "< M) ";
|
||||
if(has_depth)
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
||||
else
|
||||
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + (beta?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
|
||||
}
|
||||
if((n+1)%p_.simd_width==0){
|
||||
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
|
||||
@@ -564,18 +568,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
}
|
||||
|
||||
return stream.str();
|
||||
|
||||
#undef VLOAD
|
||||
#undef VST0RE
|
||||
}
|
||||
}
|
||||
|
||||
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
||||
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
||||
array_base const & A, array_base const & B, array_base const & C,
|
||||
value_scalar const & alpha, value_scalar const & beta,
|
||||
driver::Program const & program, std::string const & suffix, execution_options_type const & options)
|
||||
{
|
||||
{
|
||||
using tools::align;
|
||||
|
||||
if(M==0 || N==0 || K==0)
|
||||
@@ -649,10 +650,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
options.enqueue(program.context(), reduce, global, local);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int_t> gemm::infos(math_expression const & expression, symbolic::preset::gemm::args& arguments) const
|
||||
{
|
||||
std::vector<int_t> gemm::infos(math_expression const & expression, symbolic::preset::gemm::args& arguments) const
|
||||
{
|
||||
math_expression::container_type const & array = expression.tree();
|
||||
std::size_t root = expression.root();
|
||||
arguments = symbolic::preset::gemm::check(array, root);
|
||||
@@ -660,25 +661,25 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
int_t N = arguments.C->array->shape()[1];
|
||||
int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1];
|
||||
return {M, N, K};
|
||||
}
|
||||
}
|
||||
|
||||
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||
{
|
||||
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
|
||||
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
|
||||
else throw;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int_t> gemm::input_sizes(math_expression const & expressions) const
|
||||
{
|
||||
std::vector<int_t> gemm::input_sizes(math_expression const & expressions) const
|
||||
{
|
||||
symbolic::preset::gemm::args dummy;
|
||||
return infos((math_expression&)expressions, dummy);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, execution_handler const & control)
|
||||
{
|
||||
void gemm::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, execution_handler const & control)
|
||||
{
|
||||
using namespace tools;
|
||||
|
||||
gemm & fallback = (gemm&)fallback_base;
|
||||
@@ -716,44 +717,44 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
{
|
||||
enqueue_block(queue, M, N, K, *pA, *pB, *pC, args.alpha, args.beta, program, suffix, options);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
gemm_nn::gemm_nn(unsigned int simd
|
||||
//
|
||||
gemm_nn::gemm_nn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
||||
{
|
||||
}
|
||||
{
|
||||
}
|
||||
|
||||
//
|
||||
gemm_tn::gemm_tn(unsigned int simd
|
||||
//
|
||||
gemm_tn::gemm_tn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
||||
{ }
|
||||
{ }
|
||||
|
||||
//
|
||||
gemm_nt::gemm_nt(unsigned int simd
|
||||
//
|
||||
gemm_nt::gemm_nt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
||||
{ }
|
||||
{ }
|
||||
|
||||
//
|
||||
gemm_tt::gemm_tt(unsigned int simd
|
||||
//
|
||||
gemm_tt::gemm_tt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
||||
{ }
|
||||
{ }
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user