Backend: No longer creates a temporary when evaluating C = alpha*dot(op(A), op(B)) + beta*C

This commit is contained in:
Philippe Tillet
2015-06-27 17:55:01 -07:00
parent 3525edd54c
commit 0e207e7ca4
7 changed files with 206 additions and 67 deletions

View File

@@ -7,6 +7,7 @@
#include <CL/cl.hpp>
#include "isaac/model/model.h"
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac
{
@@ -157,7 +158,6 @@ namespace isaac
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
}
}
}
/** @brief Executes a array_expression on the given models map*/
@@ -172,61 +172,70 @@ namespace isaac
//Todo: technically the datatype should be per temporary
numeric_type dtype = root_save.lhs.dtype;
detail::breakpoints_t breakpoints;
breakpoints.reserve(8);
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
expression_type final_type;
//GEMM
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
final_type = args.type;
}
//Default
else
current_type=MATRIX_AXPY_TYPE;
expression_type final_type = current_type;
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
std::vector<tools::shared_ptr<array> > temporaries_;
/*----Compute required temporaries----*/
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
{
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
array_expression::node const & node = tree[rit->second->node_index];
array_expression::node const & lmost = lhs_most(tree, node);
detail::breakpoints_t breakpoints;
breakpoints.reserve(8);
//Creates temporary
tools::shared_ptr<array> tmp;
switch(rit->first){
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
current_type=VECTOR_AXPY_TYPE;
else
current_type=MATRIX_AXPY_TYPE;
final_type = current_type;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
std::vector<tools::shared_ptr<array> > temporaries_;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
/*----Compute required temporaries----*/
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
{
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
array_expression::node const & node = tree[rit->second->node_index];
array_expression::node const & lmost = lhs_most(tree, node);
default: throw std::invalid_argument("Unrecognized operation");
}
temporaries_.push_back(tmp);
//Creates temporary
tools::shared_ptr<array> tmp;
switch(rit->first){
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *rit->second;
tree[rootidx].rhs.type_family = rit->second->type_family;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
//Execute
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
//Incorporates the temporary within the array_expression
fill(*rit->second, (array&)*tmp);
default: throw std::invalid_argument("Unrecognized operation");
}
temporaries_.push_back(tmp);
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *rit->second;
tree[rootidx].rhs.type_family = rit->second->type_family;
//Execute
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
//Incorporates the temporary within the array_expression
fill(*rit->second, (array&)*tmp);
}
}
/*-----Compute final expression-----*/

View File

@@ -4,6 +4,7 @@
#include "isaac/value_scalar.h"
#include <CL/cl.hpp>
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac
{

78
lib/symbolic/preset.cpp Normal file
View File

@@ -0,0 +1,78 @@
#include "isaac/symbolic/preset.h"
namespace isaac
{
namespace symbolic
{
namespace preset
{
/**
 * @brief Examines a single node of the expression tree and stores in @p a
 *        whichever pieces of a GEMM-like pattern that node exposes.
 *
 * Two independent shapes are recognized:
 *  - a node whose operator belongs to the matrix-product family: operands that
 *    are plain arrays are recorded as A (left) and B (right), and the product
 *    operator (NN/NT/TN/TT) is translated into the matching expression type;
 *  - a multiplication node whose left operand is a value:
 *      - if the right operand is a sub-expression rooted at a matrix product,
 *        the value is recorded as alpha and that sub-expression is visited;
 *      - if the right operand is a plain array, the value is recorded as beta
 *        and the array as C.
 *
 * Fields of @p a that are not matched by this node are left untouched.
 *
 * @param tree    container holding the nodes of the expression
 * @param rootidx index, within @p tree, of the node to examine
 * @param a       accumulator receiving pointers into @p tree and the product type
 */
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
{
  array_expression::node & current = tree[rootidx];

  // Product node: keep the array operands and the transposition layout
  if(current.op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
  {
    if(current.lhs.type_family==ARRAY_TYPE_FAMILY)
      a.A = &current.lhs;
    if(current.rhs.type_family==ARRAY_TYPE_FAMILY)
      a.B = &current.rhs;

    switch(current.op.type)
    {
      case OPERATOR_MATRIX_PRODUCT_NN_TYPE:
        a.type = MATRIX_PRODUCT_NN_TYPE;
        break;
      case OPERATOR_MATRIX_PRODUCT_NT_TYPE:
        a.type = MATRIX_PRODUCT_NT_TYPE;
        break;
      case OPERATOR_MATRIX_PRODUCT_TN_TYPE:
        a.type = MATRIX_PRODUCT_TN_TYPE;
        break;
      case OPERATOR_MATRIX_PRODUCT_TT_TYPE:
        a.type = MATRIX_PRODUCT_TT_TYPE;
        break;
      default:
        break;
    }
  }

  // From here on, only "value * something" nodes are of interest
  if(current.op.type!=OPERATOR_MULT_TYPE)
    return;
  if(current.lhs.type_family!=VALUE_TYPE_FAMILY)
    return;

  if(current.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
  {
    // alpha*PROD: only descend when the sub-expression is itself a matrix product
    size_t const child = current.rhs.node_index;
    if(tree[child].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
    {
      a.alpha = &current.lhs;
      handle_node(tree, child, a);
    }
  }
  else if(current.rhs.type_family==ARRAY_TYPE_FAMILY)
  {
    // beta*C
    a.beta = &current.lhs;
    a.C = &current.rhs;
  }
}
/**
 * @brief Tries to match the assignment rooted at @p rootidx against the
 *        pattern C = alpha*dot(op(A), op(B)) + beta*C.
 *
 * The right-hand side of the root is inspected only when it is a composite
 * expression. If that expression is an addition or a subtraction, each operand
 * that is itself composite is handed to handle_node(); otherwise the
 * expression as a whole is handed to handle_node().
 *
 * Afterwards C is reconciled with the assigned operand (the root's left-hand
 * side): when no C was collected, C becomes the assigned operand; when a C was
 * collected but refers to a different array than the assigned operand, C is
 * reset to NULL.
 *
 * NOTE(review): subtraction is treated exactly like addition here and no sign
 * is recorded by this function -- confirm that consumers of the result account
 * for it.
 * NOTE(review): a NULL C presumably signals "pattern not matched" to the
 * caller -- confirm against the definition of gemm::args.
 *
 * @param tree    container holding the nodes of the expression
 * @param rootidx index, within @p tree, of the root (assignment) node
 * @return the collected operands; pointers refer to elements of @p tree
 */
gemm::args gemm::check(array_expression::container_type & tree, size_t rootidx)
{
  gemm::args result;

  // Captured before rootidx is redirected to the right-hand side below
  lhs_rhs_element * const assigned = &tree[rootidx].lhs;

  if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
  {
    rootidx = tree[rootidx].rhs.node_index;
    array_expression::node & top = tree[rootidx];

    bool const is_sum = top.op.type==OPERATOR_ADD_TYPE || top.op.type==OPERATOR_SUB_TYPE;
    if(!is_sum)
    {
      handle_node(tree, rootidx, result);
    }
    else
    {
      // Form X + Y: visit each operand that is a sub-expression
      if(top.lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
        handle_node(tree, top.lhs.node_index, result);
      if(top.rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
        handle_node(tree, top.rhs.node_index, result);
    }
  }

  bool const found_c = result.C != NULL;
  if(!found_c)
    result.C = assigned;
  else if(result.C->array != assigned->array)
    result.C = NULL;

  return result;
}
}
}
}