Backend: No longer creates a temporary when evaluating C = alpha*dot(op(A), op(B)) + beta*C
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <CL/cl.hpp>
|
||||
#include "isaac/model/model.h"
|
||||
#include "isaac/symbolic/expression.h"
|
||||
#include "isaac/symbolic/preset.h"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
@@ -157,7 +158,6 @@ namespace isaac
|
||||
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** @brief Executes an array_expression on the given models map */
|
||||
@@ -172,61 +172,70 @@ namespace isaac
|
||||
//Todo: technically the datatype should be per temporary
|
||||
numeric_type dtype = root_save.lhs.dtype;
|
||||
|
||||
detail::breakpoints_t breakpoints;
|
||||
breakpoints.reserve(8);
|
||||
|
||||
//Init
|
||||
expression_type current_type;
|
||||
if(root_save.lhs.array->nshape()==0)
|
||||
current_type = SCALAR_AXPY_TYPE;
|
||||
else if(root_save.lhs.array->nshape()==1)
|
||||
current_type=VECTOR_AXPY_TYPE;
|
||||
expression_type final_type;
|
||||
//GEMM
|
||||
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
|
||||
final_type = args.type;
|
||||
}
|
||||
//Default
|
||||
else
|
||||
current_type=MATRIX_AXPY_TYPE;
|
||||
expression_type final_type = current_type;
|
||||
|
||||
/*----Parse required temporaries-----*/
|
||||
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
|
||||
std::vector<tools::shared_ptr<array> > temporaries_;
|
||||
|
||||
/*----Compute required temporaries----*/
|
||||
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
|
||||
{
|
||||
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
|
||||
array_expression::node const & node = tree[rit->second->node_index];
|
||||
array_expression::node const & lmost = lhs_most(tree, node);
|
||||
detail::breakpoints_t breakpoints;
|
||||
breakpoints.reserve(8);
|
||||
|
||||
//Creates temporary
|
||||
tools::shared_ptr<array> tmp;
|
||||
switch(rit->first){
|
||||
case SCALAR_AXPY_TYPE:
|
||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
//Init
|
||||
expression_type current_type;
|
||||
if(root_save.lhs.array->nshape()==0)
|
||||
current_type = SCALAR_AXPY_TYPE;
|
||||
else if(root_save.lhs.array->nshape()==1)
|
||||
current_type=VECTOR_AXPY_TYPE;
|
||||
else
|
||||
current_type=MATRIX_AXPY_TYPE;
|
||||
final_type = current_type;
|
||||
|
||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
/*----Parse required temporaries-----*/
|
||||
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
|
||||
std::vector<tools::shared_ptr<array> > temporaries_;
|
||||
|
||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
/*----Compute required temporaries----*/
|
||||
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
|
||||
{
|
||||
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
|
||||
array_expression::node const & node = tree[rit->second->node_index];
|
||||
array_expression::node const & lmost = lhs_most(tree, node);
|
||||
|
||||
default: throw std::invalid_argument("Unrecognized operation");
|
||||
}
|
||||
temporaries_.push_back(tmp);
|
||||
//Creates temporary
|
||||
tools::shared_ptr<array> tmp;
|
||||
switch(rit->first){
|
||||
case SCALAR_AXPY_TYPE:
|
||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
|
||||
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
||||
fill(tree[rootidx].lhs, (array&)*tmp);
|
||||
tree[rootidx].rhs = *rit->second;
|
||||
tree[rootidx].rhs.type_family = rit->second->type_family;
|
||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
|
||||
//Execute
|
||||
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
tree[rootidx] = root_save;
|
||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
|
||||
//Incorporates the temporary within the array_expression
|
||||
fill(*rit->second, (array&)*tmp);
|
||||
default: throw std::invalid_argument("Unrecognized operation");
|
||||
}
|
||||
temporaries_.push_back(tmp);
|
||||
|
||||
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
||||
fill(tree[rootidx].lhs, (array&)*tmp);
|
||||
tree[rootidx].rhs = *rit->second;
|
||||
tree[rootidx].rhs.type_family = rit->second->type_family;
|
||||
|
||||
//Execute
|
||||
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||
tree[rootidx] = root_save;
|
||||
|
||||
//Incorporates the temporary within the array_expression
|
||||
fill(*rit->second, (array&)*tmp);
|
||||
}
|
||||
}
|
||||
|
||||
/*-----Compute final expression-----*/
|
||||
|
Reference in New Issue
Block a user