Backend: A lot of bugfixes in dot() for handling shapes better
This commit is contained in:
@@ -47,7 +47,8 @@ namespace detail
|
||||
|
||||
bool bypass(op_element const & op)
|
||||
{
|
||||
return op.type == OPERATOR_RESHAPE_TYPE;
|
||||
return op.type == OPERATOR_RESHAPE_TYPE
|
||||
||op.type == OPERATOR_TRANS_TYPE;
|
||||
}
|
||||
|
||||
bool is_cast(op_element const & op)
|
||||
@@ -68,8 +69,7 @@ namespace detail
|
||||
|
||||
bool is_node_leaf(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_TRANS_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
return op.type==OPERATOR_MATRIX_DIAG_TYPE
|
||||
|| op.type==OPERATOR_VDIAG_TYPE
|
||||
|| op.type==OPERATOR_REPEAT_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
@@ -212,8 +212,6 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_ELEMENT_FMIN_TYPE : return "fmin";
|
||||
case OPERATOR_ELEMENT_MAX_TYPE : return "max";
|
||||
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
||||
//Unary
|
||||
case OPERATOR_TRANS_TYPE : return "trans";
|
||||
|
||||
//Binary
|
||||
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN";
|
||||
|
@@ -82,6 +82,13 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
stream << "if(" << GlobalIdx0(backend) << "==0 &&" << GlobalIdx1(backend) << "==0)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array0", "#pointer[#start] = #namereg;"), expressions, mappings);
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
|
@@ -567,7 +567,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
value_scalar const & alpha, value_scalar const & beta,
|
||||
driver::Program & program, const char * suffix, execution_options_type const & options)
|
||||
{
|
||||
|
||||
if(M==0 || N==0 || K==0)
|
||||
return;
|
||||
|
||||
@@ -588,8 +587,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
driver::Kernel gemm(program, gemm_name);
|
||||
driver::NDRange local(p_.local_size_0, p_.local_size_1);
|
||||
|
||||
|
||||
using tools::align;
|
||||
driver::NDRange global = (strcmp(suffix,"fallback")==0)?driver::NDRange(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth):driver::NDRange(M/p_.mS, N/p_.nS, p_.depth);
|
||||
|
||||
unsigned int current_arg = 0;
|
||||
set_arguments_functor helper(binder, current_arg, gemm);
|
||||
gemm.setSizeArg(current_arg++, M);
|
||||
@@ -611,9 +612,14 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
gemm.setSizeArg(current_arg++, B.start()[0] + B.start()[1]*B.ld()/p_.simd_width);
|
||||
gemm.setSizeArg(current_arg++, B.stride()[0]);
|
||||
|
||||
// std::cout << "before " << *out << std::endl;
|
||||
|
||||
helper.set_arguments(beta.dtype(), beta.values());
|
||||
options.enqueue(program.context(), gemm, global, local);
|
||||
|
||||
options.queue(program.context()).synchronize();
|
||||
// std::cout << "after " << *out << std::endl;
|
||||
|
||||
if(p_.depth > 1)
|
||||
{
|
||||
unsigned int current_arg = 0;
|
||||
|
@@ -33,6 +33,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
{
|
||||
using tools::to_string;
|
||||
|
||||
|
||||
std::vector<mapped_mreduction*> reductions;
|
||||
expressions_tuple::data_type::const_iterator sit;
|
||||
std::vector<mapping_type>::const_iterator mit;
|
||||
@@ -114,6 +115,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
{
|
||||
std::string data_type = append_width("#scalartype",simd_width);
|
||||
|
||||
|
||||
for (const auto & e : reductions)
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
@@ -130,7 +132,6 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
e->process_recursive(stream, PARENT_NODE_TYPE, accessors);
|
||||
}
|
||||
|
||||
|
||||
//Update accumulators
|
||||
std::vector<std::string> str(simd_width);
|
||||
if (simd_width==1)
|
||||
@@ -240,6 +241,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << _size_t << " gsize1 = " << GlobalSize1(backend) <<";" << std::endl;
|
||||
|
||||
|
||||
|
||||
stream << _size_t << " upper_bound_1 = ( M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << ";" << std::endl;
|
||||
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
Reference in New Issue
Block a user