diff --git a/include/isaac/kernels/templates/reduce_1d.h b/include/isaac/kernels/templates/reduce_1d.h index 13ae72e06..3dc8b465a 100644 --- a/include/isaac/kernels/templates/reduce_1d.h +++ b/include/isaac/kernels/templates/reduce_1d.h @@ -22,6 +22,7 @@ class reduce_1d : public base_impl private: unsigned int lmem_usage(math_expression const & expressions) const; int is_invalid_impl(driver::Device const &, math_expression const &) const; + unsigned int temporary_workspace(math_expression const & expressions) const; inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector exprs, std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const; std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const; diff --git a/include/isaac/kernels/templates/reduce_2d.h b/include/isaac/kernels/templates/reduce_2d.h index dc313b4a5..81b27b3b9 100644 --- a/include/isaac/kernels/templates/reduce_2d.h +++ b/include/isaac/kernels/templates/reduce_2d.h @@ -31,8 +31,9 @@ protected: }; reduce_2d(reduce_2d::parameters_type const & , reduce_1d_type, binding_policy_t); private: - virtual int is_invalid_impl(driver::Device const &, math_expression const &) const; + int is_invalid_impl(driver::Device const &, math_expression const &) const; unsigned int lmem_usage(math_expression const &) const; + unsigned int temporary_workspace(math_expression const & expressions) const; std::string generate_impl(std::string const & suffix, math_expression const &, driver::Device const & device, mapping_type const &) const; public: virtual std::vector input_sizes(math_expression const & expressions) const; diff --git a/lib/kernels/templates/reduce_1d.cpp b/lib/kernels/templates/reduce_1d.cpp index 310991b95..7cb20ad47 100644 --- a/lib/kernels/templates/reduce_1d.cpp +++ b/lib/kernels/templates/reduce_1d.cpp @@ -33,6 +33,13 @@ int reduce_1d::is_invalid_impl(driver::Device const &, math_expression const &) return TEMPLATE_VALID; } +unsigned int reduce_1d::temporary_workspace(math_expression const &) const +{ + if(p_.num_groups > 1) + return p_.num_groups; + return 0; +} + inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector exprs, std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const { diff --git a/lib/kernels/templates/reduce_2d.cpp b/lib/kernels/templates/reduce_2d.cpp index 16bc87198..b824c2069 100644 --- a/lib/kernels/templates/reduce_2d.cpp +++ b/lib/kernels/templates/reduce_2d.cpp @@ -34,6 +34,15 @@ unsigned int reduce_2d::lmem_usage(const math_expression&) const return (p_.local_size_0+1)*p_.local_size_1; } +unsigned int reduce_2d::temporary_workspace(math_expression const & expressions) const +{ + std::vector MN = input_sizes(expressions); + int_t M = MN[0]; + if(p_.num_groups_0 > 1) + return M*p_.num_groups_0; + return 0; +} + std::string reduce_2d::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const { using tools::to_string; diff --git a/python/setup.py b/python/setup.py index 390f258a4..ba295b97d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -73,7 +73,7 @@ def main(): libraries += ['gnustl_shared'] #Source files - src = 'src/lib/wrap/cublas.cpp src/lib/wrap/clBLAS.cpp src/lib/exception/operation_not_supported.cpp src/lib/exception/unknown_datatype.cpp src/lib/value_scalar.cpp src/lib/array.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/kernels/binder.cpp src/lib/kernels/keywords.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/elementwise_1d.cpp src/lib/kernels/templates/matrix_product.cpp src/lib/kernels/templates/reduce_2d.cpp src/lib/kernels/templates/reduce_1d.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/elementwise_2d.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/stream.cpp src/lib/driver/dispatch.cpp src/lib/driver/kernel.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/check.cpp src/lib/driver/command_queue.cpp src/lib/driver/handle.cpp src/lib/driver/context.cpp src/lib/driver/program.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/profiles/profiles.cpp src/lib/profiles/presets.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']] + src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/dispatch.cpp src/lib/driver/program_cache.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/profiles/presets.cpp src/lib/profiles/profiles.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/kernels/templates/reduce_2d.cpp src/lib/kernels/templates/elementwise_2d.cpp src/lib/kernels/templates/elementwise_1d.cpp src/lib/kernels/templates/reduce_1d.cpp src/lib/kernels/templates/matrix_product.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/stream.cpp src/lib/kernels/parse.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/wrap/clBLAS.cpp src/lib/wrap/cublas.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']] boostsrc = 'external/boost/libs/' for s in ['numpy','python','smart_ptr','system','thread']: src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]