Kernels/REDUCE: added temporary workspace information

This commit is contained in:
Philippe Tillet
2015-12-18 18:10:05 -05:00
parent db8ef4b933
commit acd460402d
5 changed files with 20 additions and 2 deletions

View File

@@ -22,6 +22,7 @@ class reduce_1d : public base_impl<reduce_1d, reduce_1d_parameters>
private: private:
unsigned int lmem_usage(math_expression const & expressions) const; unsigned int lmem_usage(math_expression const & expressions) const;
int is_invalid_impl(driver::Device const &, math_expression const &) const; int is_invalid_impl(driver::Device const &, math_expression const &) const;
unsigned int temporary_workspace(math_expression const & expressions) const;
inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs, inline void reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const; std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const;
std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const; std::string generate_impl(std::string const & suffix, math_expression const & expressions, driver::Device const & device, mapping_type const & mapping) const;

View File

@@ -31,8 +31,9 @@ protected:
}; };
reduce_2d(reduce_2d::parameters_type const & , reduce_1d_type, binding_policy_t); reduce_2d(reduce_2d::parameters_type const & , reduce_1d_type, binding_policy_t);
private: private:
virtual int is_invalid_impl(driver::Device const &, math_expression const &) const; int is_invalid_impl(driver::Device const &, math_expression const &) const;
unsigned int lmem_usage(math_expression const &) const; unsigned int lmem_usage(math_expression const &) const;
unsigned int temporary_workspace(math_expression const & expressions) const;
std::string generate_impl(std::string const & suffix, math_expression const &, driver::Device const & device, mapping_type const &) const; std::string generate_impl(std::string const & suffix, math_expression const &, driver::Device const & device, mapping_type const &) const;
public: public:
virtual std::vector<int_t> input_sizes(math_expression const & expressions) const; virtual std::vector<int_t> input_sizes(math_expression const & expressions) const;

View File

@@ -33,6 +33,13 @@ int reduce_1d::is_invalid_impl(driver::Device const &, math_expression const &)
return TEMPLATE_VALID; return TEMPLATE_VALID;
} }
unsigned int reduce_1d::temporary_workspace(math_expression const &) const
{
if(p_.num_groups > 1)
return p_.num_groups;
return 0;
}
inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs, inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
{ {

View File

@@ -34,6 +34,15 @@ unsigned int reduce_2d::lmem_usage(const math_expression&) const
return (p_.local_size_0+1)*p_.local_size_1; return (p_.local_size_0+1)*p_.local_size_1;
} }
unsigned int reduce_2d::temporary_workspace(math_expression const & expressions) const
{
std::vector<int_t> MN = input_sizes(expressions);
int_t M = MN[0];
if(p_.num_groups_0 > 1)
return M*p_.num_groups_0;
return 0;
}
std::string reduce_2d::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const std::string reduce_2d::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const
{ {
using tools::to_string; using tools::to_string;

View File

@@ -73,7 +73,7 @@ def main():
libraries += ['gnustl_shared'] libraries += ['gnustl_shared']
#Source files #Source files
src = 'src/lib/wrap/cublas.cpp src/lib/wrap/clBLAS.cpp src/lib/exception/operation_not_supported.cpp src/lib/exception/unknown_datatype.cpp src/lib/value_scalar.cpp src/lib/array.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/kernels/binder.cpp src/lib/kernels/keywords.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/elementwise_1d.cpp src/lib/kernels/templates/matrix_product.cpp src/lib/kernels/templates/reduce_2d.cpp src/lib/kernels/templates/reduce_1d.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/elementwise_2d.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/stream.cpp src/lib/driver/dispatch.cpp src/lib/driver/kernel.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/check.cpp src/lib/driver/command_queue.cpp src/lib/driver/handle.cpp src/lib/driver/context.cpp src/lib/driver/program.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/profiles/profiles.cpp src/lib/profiles/presets.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']] src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/dispatch.cpp src/lib/driver/program_cache.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/profiles/presets.cpp src/lib/profiles/profiles.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/kernels/templates/reduce_2d.cpp src/lib/kernels/templates/elementwise_2d.cpp src/lib/kernels/templates/elementwise_1d.cpp src/lib/kernels/templates/reduce_1d.cpp src/lib/kernels/templates/matrix_product.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/stream.cpp src/lib/kernels/parse.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/wrap/clBLAS.cpp src/lib/wrap/cublas.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/' boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']: for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x] src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]