diff --git a/bench/blas.cpp b/bench/blas.cpp index 5ec5820ca..090e29a4f 100644 --- a/bench/blas.cpp +++ b/bench/blas.cpp @@ -265,7 +265,6 @@ void bench(isc::numeric_type dtype, std::string operation) #endif std::cout << std::endl; } - std::cout << "\n\n" << std::flush; } if(operation.substr(0, 4)=="gemv") @@ -334,6 +333,8 @@ void bench(isc::numeric_type dtype, std::string operation) MNKs.push_back(std::make_tuple('N','N',169,384,2304)); MNKs.push_back(std::make_tuple('N','N',169,192,1728)); MNKs.push_back(std::make_tuple('N','N',169,128,1728)); + MNKs.push_back(std::make_tuple('N','T',256,4096,9216)); + //AlexNet (Backward) MNKs.push_back(std::make_tuple('T','N',1728,128,169)); MNKs.push_back(std::make_tuple('T','N',1728,192,169)); diff --git a/include/isaac/kernels/templates/base.h b/include/isaac/kernels/templates/base.h index b09e64041..746d4c9fe 100644 --- a/include/isaac/kernels/templates/base.h +++ b/include/isaac/kernels/templates/base.h @@ -67,6 +67,7 @@ private: virtual std::string generate_impl(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector const & mapping) const = 0; public: base(binding_policy_t binding_policy); + virtual unsigned int temporary_workspace(expressions_tuple const &) const; virtual unsigned int lmem_usage(expressions_tuple const &) const; virtual unsigned int registers_usage(expressions_tuple const &) const; virtual std::vector input_sizes(expressions_tuple const & expressions) const = 0; diff --git a/include/isaac/kernels/templates/gemm.h b/include/isaac/kernels/templates/gemm.h index c176743bb..f6d2fbeef 100644 --- a/include/isaac/kernels/templates/gemm.h +++ b/include/isaac/kernels/templates/gemm.h @@ -42,6 +42,7 @@ struct gemm_parameters : public base::parameters_type class gemm : public base_impl { private: + unsigned int temporary_workspace(expressions_tuple const & expressions) const; unsigned int lmem_usage(expressions_tuple const & expressions) const; unsigned int registers_usage(expressions_tuple const & expressions) const; int is_invalid_impl(driver::Device const &, expressions_tuple const &) const; diff --git a/lib/kernels/templates/base.cpp b/lib/kernels/templates/base.cpp index 0aeb34a25..7032c481a 100644 --- a/lib/kernels/templates/base.cpp +++ b/lib/kernels/templates/base.cpp @@ -73,6 +73,9 @@ unsigned int base::lmem_usage(expressions_tuple const &) const unsigned int base::registers_usage(expressions_tuple const &) const { return 0; } +unsigned int base::temporary_workspace(expressions_tuple const &) const +{ return 0; } + base::~base() { } @@ -81,7 +84,8 @@ std::string base::generate(std::string const & suffix, expressions_tuple const & expressions_tuple::data_type::const_iterator sit; std::vector::iterator mit; - if(int err = is_invalid(expressions, device)) + int err = is_invalid(expressions, device); + if(err != 0 && err != TEMPLATE_TEMPORARY_TOO_LARGE) throw operation_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err)); //Create mapping @@ -142,6 +146,10 @@ int base_impl::is_invalid(expressions_tuple const & expressions, d if (p_.simd_width!=1 && p_.simd_width!=2 && p_.simd_width!=3 && p_.simd_width!=4) return TEMPLATE_INVALID_SIMD_WIDTH; + //Temporary workspace + if(temporary_workspace(expressions) > 2e6) + return TEMPLATE_TEMPORARY_TOO_LARGE; + return is_invalid_impl(device, expressions); } diff --git a/lib/kernels/templates/gemm.cpp b/lib/kernels/templates/gemm.cpp index d1c2ad6a1..dcf057a58 100644 --- a/lib/kernels/templates/gemm.cpp +++ b/lib/kernels/templates/gemm.cpp @@ -46,16 +46,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width return N*size_of(numeric_t); } - int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const + unsigned int gemm::temporary_workspace(expressions_tuple const & expressions) const { - std::vector MNK = input_sizes(expressions); - int_t M = MNK[0]; int_t N = MNK[1]; + std::vector MNK = input_sizes(expressions); + int_t M = MNK[0]; int_t N = MNK[1]; + if(p_.depth > 1) + return M*N*p_.depth; + return 0; + } + int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const + { if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL) return TEMPLATE_INVALID_FETCHING_POLICY_TYPE; - if(p_.depth > 1 && M*N*p_.depth > 2e6) - return TEMPLATE_TEMPORARY_TOO_LARGE; if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0) return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE; @@ -66,8 +70,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width if ( p_.kS % p_.kL == 0) return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL; - if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL) - { + if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){ if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1)) return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT; } diff --git a/lib/model/model.cpp b/lib/model/model.cpp index 19e827c43..f8eeed2a0 100644 --- a/lib/model/model.cpp +++ b/lib/model/model.cpp @@ -124,6 +124,7 @@ void model::execute(controller const & expr) } //Execution +// std::cout << std::endl << "Label: " << label << std::endl; return templates_[label]->enqueue(queue_, program, tools::to_string(label), *fallback_, expr); } diff --git a/lib/wrap/clBLAS.cpp b/lib/wrap/clBLAS.cpp index bf2dc12ec..54600d8c1 100644 --- a/lib/wrap/clBLAS.cpp +++ b/lib/wrap/clBLAS.cpp @@ -176,6 +176,7 @@ extern "C" cl_uint numCommandQueues, cl_command_queue *commandQueues,\ cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)\ {\ + /*std::cout << transA << " " <