Feature: Merged kernel-fusion branch
* Fuses multiple AXPY kernels * Adds the possibility of thread-wise for loops in AXPY-like kernels
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/tuple.h"
|
||||
#include "isaac/kernels/keywords.h"
|
||||
#include "isaac/kernels/templates/axpy.h"
|
||||
#include "isaac/kernels/templates/dot.h"
|
||||
@@ -27,17 +28,16 @@ base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_si
|
||||
{ }
|
||||
|
||||
|
||||
bool base::requires_fallback(expressions_tuple const & expressions)
|
||||
bool base::requires_fallback(math_expression const & expression)
|
||||
{
|
||||
for (const auto & elem : expressions.data())
|
||||
for(array_expression::container_type::const_iterator itt = (elem)->tree().begin(); itt != (elem)->tree().end() ; ++itt)
|
||||
if( (itt->lhs.subtype==DENSE_ARRAY_TYPE && (std::max(itt->lhs.array->stride()[0], itt->lhs.array->stride()[1])>1 || std::max(itt->lhs.array->start()[0],itt->lhs.array->start()[1])>0))
|
||||
|| (itt->rhs.subtype==DENSE_ARRAY_TYPE && (std::max(itt->rhs.array->stride()[0], itt->rhs.array->stride()[1])>1 || std::max(itt->rhs.array->start()[0],itt->rhs.array->start()[1])>0)))
|
||||
return true;
|
||||
for(math_expression::node const & node: expression.tree())
|
||||
if( (node.lhs.subtype==DENSE_ARRAY_TYPE && (std::max(node.lhs.array->stride()[0], node.lhs.array->stride()[1])>1 || std::max(node.lhs.array->start()[0],node.lhs.array->start()[1])>0))
|
||||
|| (node.rhs.subtype==DENSE_ARRAY_TYPE && (std::max(node.rhs.array->stride()[0], node.rhs.array->stride()[1])>1 || std::max(node.rhs.array->start()[0],node.rhs.array->start()[1])>0)))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
int_t base::vector_size(array_expression::node const & node)
|
||||
int_t base::vector_size(math_expression::node const & node)
|
||||
{
|
||||
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
|
||||
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
|
||||
@@ -50,7 +50,7 @@ int_t base::vector_size(array_expression::node const & node)
|
||||
|
||||
}
|
||||
|
||||
std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
||||
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
|
||||
{
|
||||
if (node.op.type==OPERATOR_VDIAG_TYPE)
|
||||
{
|
||||
@@ -58,7 +58,12 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
||||
return std::make_pair(size,size);
|
||||
}
|
||||
else if(node.op.type==OPERATOR_REPEAT_TYPE)
|
||||
return std::make_pair(node.lhs.array->shape()[0]*node.rhs.tuple.rep1, node.lhs.array->shape()[1]*node.rhs.tuple.rep2);
|
||||
{
|
||||
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
|
||||
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);
|
||||
std::cout << rep0 << " " << rep1 << std::endl;
|
||||
return std::make_pair(node.lhs.array->shape()[0]*rep0, node.lhs.array->shape()[1]*rep1);
|
||||
}
|
||||
else
|
||||
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
|
||||
}
|
||||
@@ -67,43 +72,39 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
||||
base::base(binding_policy_t binding_policy) : binding_policy_(binding_policy)
|
||||
{}
|
||||
|
||||
unsigned int base::lmem_usage(expressions_tuple const &) const
|
||||
unsigned int base::lmem_usage(math_expression const &) const
|
||||
{ return 0; }
|
||||
|
||||
unsigned int base::registers_usage(expressions_tuple const &) const
|
||||
unsigned int base::registers_usage(math_expression const &) const
|
||||
{ return 0; }
|
||||
|
||||
unsigned int base::temporary_workspace(expressions_tuple const &) const
|
||||
unsigned int base::temporary_workspace(math_expression const &) const
|
||||
{ return 0; }
|
||||
|
||||
base::~base()
|
||||
{
|
||||
}
|
||||
|
||||
std::string base::generate(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device)
|
||||
std::string base::generate(std::string const & suffix, math_expression const & expression, driver::Device const & device)
|
||||
{
|
||||
expressions_tuple::data_type::const_iterator sit;
|
||||
std::vector<mapping_type>::iterator mit;
|
||||
int err = is_invalid(expressions, device);
|
||||
int err = is_invalid(expression, device);
|
||||
if(err != 0)
|
||||
throw operation_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err));
|
||||
|
||||
//Create mapping
|
||||
std::vector<mapping_type> mappings(expressions.data().size());
|
||||
mapping_type mapping;
|
||||
std::unique_ptr<symbolic_binder> binder;
|
||||
if (binding_policy_==BIND_TO_HANDLE)
|
||||
binder.reset(new bind_to_handle());
|
||||
if (binding_policy_==BIND_SEQUENTIAL)
|
||||
binder.reset(new bind_sequential());
|
||||
else
|
||||
binder.reset(new bind_all_unique());
|
||||
binder.reset(new bind_independent());
|
||||
|
||||
for (mit = mappings.begin(), sit = expressions.data().begin(); sit != expressions.data().end(); ++sit, ++mit)
|
||||
traverse(**sit, (*sit)->root(), map_functor(*binder,*mit,device), true);
|
||||
|
||||
return generate_impl(suffix, expressions, device, mappings);
|
||||
traverse(expression, expression.root(), map_functor(*binder, mapping, device), true);
|
||||
return generate_impl(suffix, expression, device, mapping);
|
||||
}
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, math_expression const &) const
|
||||
{ return TEMPLATE_VALID; }
|
||||
|
||||
template<class TType, class PType>
|
||||
@@ -123,7 +124,7 @@ std::shared_ptr<base> base_impl<TType, PType>::clone() const
|
||||
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, driver::Device const & device) const
|
||||
int base_impl<TType, PType>::is_invalid(math_expression const & expressions, driver::Device const & device) const
|
||||
{
|
||||
//Query device informations
|
||||
size_t lmem_available = device.local_mem_size();
|
||||
|
Reference in New Issue
Block a user