Feature: Merged kernel-fusion branch

* Fuses multiple AXPY kernels
* Adds the possibility of thread-wise for-loops in AXPY-like kernels
This commit is contained in:
Philippe Tillet
2015-09-30 15:31:41 -04:00
parent 149441b9e2
commit feeb1e9862
64 changed files with 10047 additions and 1119 deletions

View File

@@ -3,6 +3,7 @@
#include <string>
#include "isaac/array.h"
#include "isaac/tuple.h"
#include "isaac/kernels/keywords.h"
#include "isaac/kernels/templates/axpy.h"
#include "isaac/kernels/templates/dot.h"
@@ -27,17 +28,16 @@ base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_si
{ }
// Reports whether the fallback path is required for the given expression(s).
// Returns true as soon as any tree node has a DENSE_ARRAY_TYPE operand (lhs or rhs)
// for which max(stride()[0], stride()[1]) > 1 or max(start()[0], start()[1]) > 0;
// returns false when no node matches.
// NOTE(review): this text is a rendered diff without +/- markers. The
// expressions_tuple signature and iterator loop appear to be the removed version,
// the math_expression signature and range-for the added one — confirm against the commit.
bool base::requires_fallback(expressions_tuple const & expressions)
bool base::requires_fallback(math_expression const & expression)
{
for (const auto & elem : expressions.data())
for(array_expression::container_type::const_iterator itt = (elem)->tree().begin(); itt != (elem)->tree().end() ; ++itt)
if( (itt->lhs.subtype==DENSE_ARRAY_TYPE && (std::max(itt->lhs.array->stride()[0], itt->lhs.array->stride()[1])>1 || std::max(itt->lhs.array->start()[0],itt->lhs.array->start()[1])>0))
|| (itt->rhs.subtype==DENSE_ARRAY_TYPE && (std::max(itt->rhs.array->stride()[0], itt->rhs.array->stride()[1])>1 || std::max(itt->rhs.array->start()[0],itt->rhs.array->start()[1])>0)))
return true;
for(math_expression::node const & node: expression.tree())
if( (node.lhs.subtype==DENSE_ARRAY_TYPE && (std::max(node.lhs.array->stride()[0], node.lhs.array->stride()[1])>1 || std::max(node.lhs.array->start()[0],node.lhs.array->start()[1])>0))
|| (node.rhs.subtype==DENSE_ARRAY_TYPE && (std::max(node.rhs.array->stride()[0], node.rhs.array->stride()[1])>1 || std::max(node.rhs.array->start()[0],node.rhs.array->start()[1])>0)))
return true;
// No dense-array operand with a non-unit stride or a non-zero start was found.
return false;
}
int_t base::vector_size(array_expression::node const & node)
int_t base::vector_size(math_expression::node const & node)
{
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
return std::min<int_t>(node.lhs.array->shape()[0], node.lhs.array->shape()[1]);
@@ -50,7 +50,7 @@ int_t base::vector_size(array_expression::node const & node)
}
std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
std::pair<int_t, int_t> base::matrix_size(math_expression::container_type const & tree, math_expression::node const & node)
{
if (node.op.type==OPERATOR_VDIAG_TYPE)
{
@@ -58,7 +58,12 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
return std::make_pair(size,size);
}
else if(node.op.type==OPERATOR_REPEAT_TYPE)
return std::make_pair(node.lhs.array->shape()[0]*node.rhs.tuple.rep1, node.lhs.array->shape()[1]*node.rhs.tuple.rep2);
{
size_t rep0 = tuple_get(tree, node.rhs.node_index, 0);
size_t rep1 = tuple_get(tree, node.rhs.node_index, 1);
std::cout << rep0 << " " << rep1 << std::endl;
return std::make_pair(node.lhs.array->shape()[0]*rep0, node.lhs.array->shape()[1]*rep1);
}
else
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
}
@@ -67,43 +72,39 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
// Constructs the template base, storing the requested binding policy in binding_policy_.
base::base(binding_policy_t binding_policy) : binding_policy_(binding_policy)
{}
// Base-class implementation: the expression argument is unnamed and ignored; always returns 0.
// NOTE(review): two signatures are shown because this is a rendered diff
// (expressions_tuple presumably removed, math_expression presumably added) — confirm.
unsigned int base::lmem_usage(expressions_tuple const &) const
unsigned int base::lmem_usage(math_expression const &) const
{ return 0; }
// Base-class implementation: the expression argument is unnamed and ignored; always returns 0.
// NOTE(review): two signatures are shown because this is a rendered diff
// (expressions_tuple presumably removed, math_expression presumably added) — confirm.
unsigned int base::registers_usage(expressions_tuple const &) const
unsigned int base::registers_usage(math_expression const &) const
{ return 0; }
// Base-class implementation: the expression argument is unnamed and ignored; always returns 0.
// NOTE(review): two signatures are shown because this is a rendered diff
// (expressions_tuple presumably removed, math_expression presumably added) — confirm.
unsigned int base::temporary_workspace(expressions_tuple const &) const
unsigned int base::temporary_workspace(math_expression const &) const
{ return 0; }
// Destructor with an empty body; defined out of line.
base::~base()
{
}
// Generates the source string for this template.
// Steps visible here: (1) validate via is_invalid() and throw
// operation_not_supported_exception (message includes the error code) when the
// result is non-zero; (2) pick a symbolic_binder according to binding_policy_;
// (3) traverse the expression from its root with a map_functor to fill the
// mapping; (4) forward suffix, expression, device and mapping to generate_impl()
// and return its result.
// NOTE(review): this text is a rendered diff without +/- markers, so old and new
// lines are interleaved. The expressions_tuple / mappings-vector / BIND_TO_HANDLE /
// bind_all_unique lines appear to be the removed version; the math_expression /
// single mapping / BIND_SEQUENTIAL / bind_independent lines the added one — confirm.
std::string base::generate(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device)
std::string base::generate(std::string const & suffix, math_expression const & expression, driver::Device const & device)
{
expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::iterator mit;
// A non-zero code from is_invalid() is treated as "parameters not usable".
int err = is_invalid(expressions, device);
int err = is_invalid(expression, device);
if(err != 0)
throw operation_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err));
//Create mapping
std::vector<mapping_type> mappings(expressions.data().size());
mapping_type mapping;
// The binder implementation is selected from binding_policy_.
std::unique_ptr<symbolic_binder> binder;
if (binding_policy_==BIND_TO_HANDLE)
binder.reset(new bind_to_handle());
if (binding_policy_==BIND_SEQUENTIAL)
binder.reset(new bind_sequential());
else
binder.reset(new bind_all_unique());
binder.reset(new bind_independent());
// Walk the expression tree from the root, populating the mapping through map_functor.
for (mit = mappings.begin(), sit = expressions.data().begin(); sit != expressions.data().end(); ++sit, ++mit)
traverse(**sit, (*sit)->root(), map_functor(*binder,*mit,device), true);
return generate_impl(suffix, expressions, device, mappings);
traverse(expression, expression.root(), map_functor(*binder, mapping, device), true);
return generate_impl(suffix, expression, device, mapping);
}
// Default validity hook: both arguments are unnamed and ignored; always returns TEMPLATE_VALID.
// NOTE(review): two signatures are shown because this is a rendered diff
// (expressions_tuple presumably removed, math_expression presumably added) — confirm.
template<class TType, class PType>
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int base_impl<TType, PType>::is_invalid_impl(driver::Device const &, math_expression const &) const
{ return TEMPLATE_VALID; }
template<class TType, class PType>
@@ -123,7 +124,7 @@ std::shared_ptr<base> base_impl<TType, PType>::clone() const
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
template<class TType, class PType>
int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, driver::Device const & device) const
int base_impl<TType, PType>::is_invalid(math_expression const & expressions, driver::Device const & device) const
{
//Query device informations
size_t lmem_available = device.local_mem_size();