Backend: A lot of bugfixes in dot() for handling shapes better

This commit is contained in:
Philippe Tillet
2015-06-30 17:55:57 -04:00
parent e7cabf65ac
commit cf2dba43ef
12 changed files with 108 additions and 73 deletions

View File

@@ -567,7 +567,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options)
{
if(M==0 || N==0 || K==0)
return;
@@ -588,8 +587,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
driver::Kernel gemm(program, gemm_name);
driver::NDRange local(p_.local_size_0, p_.local_size_1);
using tools::align;
driver::NDRange global = (strcmp(suffix,"fallback")==0)?driver::NDRange(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth):driver::NDRange(M/p_.mS, N/p_.nS, p_.depth);
unsigned int current_arg = 0;
set_arguments_functor helper(binder, current_arg, gemm);
gemm.setSizeArg(current_arg++, M);
@@ -611,9 +612,14 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
gemm.setSizeArg(current_arg++, B.start()[0] + B.start()[1]*B.ld()/p_.simd_width);
gemm.setSizeArg(current_arg++, B.stride()[0]);
// std::cout << "before " << *out << std::endl;
helper.set_arguments(beta.dtype(), beta.values());
options.enqueue(program.context(), gemm, global, local);
options.queue(program.context()).synchronize();
// std::cout << "after " << *out << std::endl;
if(p_.depth > 1)
{
unsigned int current_arg = 0;