Feature: Merged kernel-fusion branch
* Fuses multiple AXPY kernel * Possibility to add thread-wise for loops in AXPY-like kernels
This commit is contained in:
@@ -85,6 +85,8 @@ namespace isaac
|
||||
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
|
||||
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
|
||||
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
|
||||
else if(right == INVALID_EXPRESSION_TYPE) return left;
|
||||
else if(left == INVALID_EXPRESSION_TYPE) return right;
|
||||
throw;
|
||||
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
||||
return DOT_TYPE;
|
||||
@@ -103,12 +105,12 @@ namespace isaac
|
||||
}
|
||||
|
||||
/** @brief Parses the breakpoints for a given expression tree */
|
||||
static void parse(array_expression::container_type&array, size_t idx,
|
||||
static void parse(math_expression::container_type&array, size_t idx,
|
||||
breakpoints_t & breakpoints,
|
||||
expression_type & final_type,
|
||||
bool is_first = true)
|
||||
{
|
||||
array_expression::node & node = array[idx];
|
||||
math_expression::node & node = array[idx];
|
||||
|
||||
//Left
|
||||
expression_type type_left = INVALID_EXPRESSION_TYPE;
|
||||
@@ -144,17 +146,17 @@ namespace isaac
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Executes a array_expression on the given models map*/
|
||||
void execute(controller<array_expression> const & c, profiles::map_type & profiles)
|
||||
/** @brief Executes a math_expression on the given models map*/
|
||||
void execute(execution_handler const & c, profiles::map_type & profiles)
|
||||
{
|
||||
array_expression expression = c.x();
|
||||
math_expression expression = c.x();
|
||||
driver::Context const & context = expression.context();
|
||||
size_t rootidx = expression.root();
|
||||
array_expression::container_type & tree = const_cast<array_expression::container_type &>(expression.tree());
|
||||
array_expression::node root_save = tree[rootidx];
|
||||
math_expression::container_type & tree = const_cast<math_expression::container_type &>(expression.tree());
|
||||
math_expression::node root_save = tree[rootidx];
|
||||
|
||||
//Todo: technically the datatype should be per temporary
|
||||
numeric_type dtype = root_save.lhs.dtype;
|
||||
numeric_type dtype = expression.dtype();
|
||||
|
||||
expression_type final_type;
|
||||
//GEMM
|
||||
@@ -169,7 +171,7 @@ namespace isaac
|
||||
|
||||
//Init
|
||||
expression_type current_type;
|
||||
if(root_save.lhs.array->nshape()<=1)
|
||||
if(expression.nshape()<=1)
|
||||
current_type=AXPY_TYPE;
|
||||
else
|
||||
current_type=GER_TYPE;
|
||||
@@ -183,8 +185,8 @@ namespace isaac
|
||||
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
|
||||
{
|
||||
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(it->first, dtype)];
|
||||
array_expression::node const & node = tree[it->second->node_index];
|
||||
array_expression::node const & lmost = lhs_most(tree, node);
|
||||
math_expression::node const & node = tree[it->second->node_index];
|
||||
math_expression::node const & lmost = lhs_most(tree, node);
|
||||
|
||||
//Creates temporary
|
||||
std::shared_ptr<array> tmp;
|
||||
@@ -211,16 +213,21 @@ namespace isaac
|
||||
tree[rootidx].rhs.type_family = it->second->type_family;
|
||||
|
||||
//Execute
|
||||
profile->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
tree[rootidx] = root_save;
|
||||
|
||||
//Incorporates the temporary within the array_expression
|
||||
//Incorporates the temporary within the math_expression
|
||||
fill(*it->second, (array&)*tmp);
|
||||
}
|
||||
}
|
||||
|
||||
/*-----Compute final expression-----*/
|
||||
profiles[std::make_pair(final_type, dtype)]->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
profiles[std::make_pair(final_type, dtype)]->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
}
|
||||
|
||||
void execute(execution_handler const & c)
|
||||
{
|
||||
execute(c, isaac::profiles::get(c.execution_options().queue(c.x().context())));
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user