Backend: No longer creates a temporary when evaluating C = alpha*dot(op(A), op(B)) + beta*C

This commit is contained in:
Philippe Tillet
2015-06-27 17:55:01 -07:00
parent 3525edd54c
commit 0e207e7ca4
7 changed files with 206 additions and 67 deletions

View File

@@ -7,6 +7,7 @@
#include <CL/cl.hpp>
#include "isaac/model/model.h"
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac
{
@@ -157,7 +158,6 @@ namespace isaac
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
}
}
}
/** @brief Executes an array_expression on the given models map */
@@ -172,61 +172,70 @@ namespace isaac
//Todo: technically the datatype should be per temporary
numeric_type dtype = root_save.lhs.dtype;
detail::breakpoints_t breakpoints;
breakpoints.reserve(8);
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
expression_type final_type;
//GEMM
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
final_type = args.type;
}
//Default
else
current_type=MATRIX_AXPY_TYPE;
expression_type final_type = current_type;
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
std::vector<tools::shared_ptr<array> > temporaries_;
/*----Compute required temporaries----*/
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
{
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
array_expression::node const & node = tree[rit->second->node_index];
array_expression::node const & lmost = lhs_most(tree, node);
detail::breakpoints_t breakpoints;
breakpoints.reserve(8);
//Creates temporary
tools::shared_ptr<array> tmp;
switch(rit->first){
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
else
current_type=MATRIX_AXPY_TYPE;
final_type = current_type;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
std::vector<tools::shared_ptr<array> > temporaries_;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
/*----Compute required temporaries----*/
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
{
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
array_expression::node const & node = tree[rit->second->node_index];
array_expression::node const & lmost = lhs_most(tree, node);
default: throw std::invalid_argument("Unrecognized operation");
}
temporaries_.push_back(tmp);
//Creates temporary
tools::shared_ptr<array> tmp;
switch(rit->first){
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *rit->second;
tree[rootidx].rhs.type_family = rit->second->type_family;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
//Execute
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
//Incorporates the temporary within the array_expression
fill(*rit->second, (array&)*tmp);
default: throw std::invalid_argument("Unrecognized operation");
}
temporaries_.push_back(tmp);
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *rit->second;
tree[rootidx].rhs.type_family = rit->second->type_family;
//Execute
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
//Incorporates the temporary within the array_expression
fill(*rit->second, (array&)*tmp);
}
}
/*-----Compute final expression-----*/