C++: improved temporaries handling
This commit is contained in:
@@ -16,147 +16,133 @@ namespace isaac
|
|||||||
{
|
{
|
||||||
typedef std::vector<std::pair<expression_type, lhs_rhs_element*> > breakpoints_t;
|
typedef std::vector<std::pair<expression_type, lhs_rhs_element*> > breakpoints_t;
|
||||||
|
|
||||||
/** @brief Determine if a particular operation requires a breakpoint
|
|
||||||
*
|
inline bool is_mmprod(expression_type x)
|
||||||
* @return std::pair<bool, expression_type> the first element is whether or not a breakpoint is required
|
|
||||||
* The second element is the type of the new operation
|
|
||||||
*/
|
|
||||||
static std::pair<bool, expression_type> is_breakpoint(expression_type current_type, op_element op, bool is_first)
|
|
||||||
{
|
{
|
||||||
using std::make_pair;
|
return x==MATRIX_PRODUCT_NN_TYPE || x==MATRIX_PRODUCT_NT_TYPE ||
|
||||||
|
x==MATRIX_PRODUCT_TN_TYPE || x==MATRIX_PRODUCT_TT_TYPE;
|
||||||
|
}
|
||||||
|
|
||||||
switch(current_type)
|
inline bool is_mvprod(expression_type x)
|
||||||
{
|
{
|
||||||
|
return x==ROW_WISE_REDUCTION_TYPE || x==COL_WISE_REDUCTION_TYPE;
|
||||||
//BLAS1 Helpers
|
|
||||||
#define HANDLE_VECTOR_AXPY(tmp) case OPERATOR_BINARY_TYPE_FAMILY:\
|
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY: return make_pair(tmp, VECTOR_AXPY_TYPE)
|
|
||||||
#define HANDLE_VECTOR_REDUCTION(tmp) case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY: return make_pair(tmp, REDUCTION_TYPE)
|
|
||||||
#define HANDLE_MATRIX_AXPY(tmp) case OPERATOR_BINARY_TYPE_FAMILY:\
|
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY: return make_pair(tmp, MATRIX_AXPY_TYPE)
|
|
||||||
|
|
||||||
//BLAS2 Helpers
|
|
||||||
#define HANDLE_ROWS_REDUCTION(tmp) case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY: return make_pair(tmp, ROW_WISE_REDUCTION_TYPE)
|
|
||||||
#define HANDLE_COLUMNS_REDUCTION(tmp) case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY: return make_pair(tmp, COL_WISE_REDUCTION_TYPE)
|
|
||||||
|
|
||||||
//BLAS3 Helpers
|
|
||||||
#define HANDLE_MATRIX_PRODUCT(tmp) case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:\
|
|
||||||
switch(op.type){\
|
|
||||||
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: return make_pair(tmp, MATRIX_PRODUCT_NN_TYPE);\
|
|
||||||
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: return make_pair(tmp, MATRIX_PRODUCT_TN_TYPE);\
|
|
||||||
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: return make_pair(tmp, MATRIX_PRODUCT_NT_TYPE);\
|
|
||||||
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: return make_pair(tmp, MATRIX_PRODUCT_TT_TYPE);\
|
|
||||||
default: assert(false && "This misformed expression shouldn't occur");\
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inside a SCALAR AXPY
|
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
|
||||||
case SCALAR_AXPY_TYPE:
|
{
|
||||||
// Reduction: No temporary
|
bool result = false;
|
||||||
switch(op.type_family){
|
switch(op.type_family)
|
||||||
HANDLE_VECTOR_REDUCTION(false);
|
{
|
||||||
default: break;
|
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
// Inside a VECTOR AXPY
|
|
||||||
case VECTOR_AXPY_TYPE:
|
|
||||||
switch(op.type_family){
|
|
||||||
HANDLE_VECTOR_REDUCTION(true);
|
|
||||||
HANDLE_ROWS_REDUCTION(false);
|
|
||||||
HANDLE_COLUMNS_REDUCTION(false);
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
// Inside a REDUCTION
|
|
||||||
case REDUCTION_TYPE:
|
|
||||||
switch(op.type_family){
|
|
||||||
HANDLE_VECTOR_REDUCTION(true);
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case COL_WISE_REDUCTION_TYPE:
|
|
||||||
case ROW_WISE_REDUCTION_TYPE:
|
|
||||||
switch(op.type_family){
|
|
||||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||||
case OPERATOR_UNARY_TYPE_FAMILY: return std::make_pair(false, current_type);
|
result |= is_mmprod(expression)
|
||||||
HANDLE_ROWS_REDUCTION(true);
|
|| (result |= expression==ROW_WISE_REDUCTION_TYPE && other==COL_WISE_REDUCTION_TYPE)
|
||||||
HANDLE_COLUMNS_REDUCTION(true);
|
|| (result |= expression==COL_WISE_REDUCTION_TYPE && other==ROW_WISE_REDUCTION_TYPE);
|
||||||
HANDLE_VECTOR_REDUCTION(true);
|
|
||||||
HANDLE_MATRIX_PRODUCT(true);
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
|
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
||||||
// Inside a MATRIX AXPY
|
result |= is_mvprod(expression)
|
||||||
// - MATRIX PRODUCTS are temporaries
|
|| expression==REDUCTION_TYPE;
|
||||||
// - REDUCTIONS are temporaries
|
break;
|
||||||
case MATRIX_AXPY_TYPE:
|
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
||||||
switch(op.type_family){
|
result |= is_mmprod(expression)
|
||||||
HANDLE_VECTOR_REDUCTION(true);
|
|| is_mvprod(expression)
|
||||||
HANDLE_MATRIX_PRODUCT(!is_first);
|
|| expression==REDUCTION_TYPE;
|
||||||
default: break;
|
break;
|
||||||
}
|
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
||||||
|
result |= is_mmprod(expression)
|
||||||
|
|| is_mvprod(expression)
|
||||||
|
|| expression==REDUCTION_TYPE;
|
||||||
|
break;
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
||||||
|
result |= (is_mmprod(expression) && !is_first)
|
||||||
|
|| is_mvprod(expression)
|
||||||
|
|| expression==REDUCTION_TYPE;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
// Inside a MATRIX PRODUCT:
|
|
||||||
// - AXPY are temporaries
|
|
||||||
// - MATRIX PRODUCTS are temporaries
|
|
||||||
case MATRIX_PRODUCT_NN_TYPE:
|
|
||||||
case MATRIX_PRODUCT_NT_TYPE:
|
|
||||||
case MATRIX_PRODUCT_TN_TYPE:
|
|
||||||
case MATRIX_PRODUCT_TT_TYPE:
|
|
||||||
switch(op.type_family){
|
|
||||||
HANDLE_MATRIX_AXPY(true);
|
|
||||||
HANDLE_MATRIX_PRODUCT(true);
|
|
||||||
HANDLE_VECTOR_REDUCTION(true);
|
|
||||||
default: break;
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
#undef HANDLE_VECTOR_AXPY
|
inline std::pair<bool, bool> has_temporary(op_element op, expression_type left, expression_type right, bool is_first)
|
||||||
#undef HANDLE_VECTOR_REDUCTION
|
{
|
||||||
#undef HANDLE_MATRIX_AXPY
|
bool has_temporary_left = has_temporary_impl(op, left, right, is_first);
|
||||||
#undef HANDLE_ROWS_REDUCTION
|
bool has_temporary_right = has_temporary_impl(op, right, left, is_first);
|
||||||
#undef HANDLE_COLUMN_REDUCTION
|
return std::make_pair(has_temporary_left, has_temporary_right);
|
||||||
#undef HANDLE_MATRIX_PRODUCT
|
}
|
||||||
|
|
||||||
return make_pair(false, current_type);
|
inline expression_type merge(op_element op, expression_type left, expression_type right)
|
||||||
|
{
|
||||||
|
switch(op.type_family)
|
||||||
|
{
|
||||||
|
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||||
|
if(is_mmprod(left))
|
||||||
|
return MATRIX_AXPY_TYPE;
|
||||||
|
return left;
|
||||||
|
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||||
|
if(left == ROW_WISE_REDUCTION_TYPE || right == ROW_WISE_REDUCTION_TYPE) return ROW_WISE_REDUCTION_TYPE;
|
||||||
|
else if(left == COL_WISE_REDUCTION_TYPE || right == COL_WISE_REDUCTION_TYPE) return COL_WISE_REDUCTION_TYPE;
|
||||||
|
else if(left == REDUCTION_TYPE || right == REDUCTION_TYPE) return REDUCTION_TYPE;
|
||||||
|
else if(left == VECTOR_AXPY_TYPE || right == VECTOR_AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?MATRIX_AXPY_TYPE:VECTOR_AXPY_TYPE;
|
||||||
|
else if(left == MATRIX_AXPY_TYPE || right == MATRIX_AXPY_TYPE) return MATRIX_AXPY_TYPE;
|
||||||
|
else if(is_mmprod(left) || is_mmprod(right)) return MATRIX_AXPY_TYPE;
|
||||||
|
std::cout << left << " " << right << std::endl;
|
||||||
|
throw;
|
||||||
|
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
||||||
|
return REDUCTION_TYPE;
|
||||||
|
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
||||||
|
return ROW_WISE_REDUCTION_TYPE;
|
||||||
|
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
||||||
|
return COL_WISE_REDUCTION_TYPE;
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
||||||
|
if(op.type==OPERATOR_MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN_TYPE;
|
||||||
|
else if(op.type==OPERATOR_MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN_TYPE;
|
||||||
|
else if(op.type==OPERATOR_MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT_TYPE;
|
||||||
|
else return MATRIX_PRODUCT_TT_TYPE;
|
||||||
|
default:
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @brief Parses the breakpoints for a given expression tree */
|
/** @brief Parses the breakpoints for a given expression tree */
|
||||||
static void parse(array_expression::container_type&array, size_t idx,
|
static void parse(array_expression::container_type&array, size_t idx,
|
||||||
expression_type current_type,
|
|
||||||
breakpoints_t & breakpoints,
|
breakpoints_t & breakpoints,
|
||||||
expression_type & final_type,
|
expression_type & final_type,
|
||||||
bool is_first = true)
|
bool is_first = true)
|
||||||
{
|
{
|
||||||
array_expression::node & node = array[idx];
|
array_expression::node & node = array[idx];
|
||||||
|
|
||||||
|
//Left
|
||||||
|
expression_type type_left = INVALID_EXPRESSION_TYPE;
|
||||||
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
||||||
|
parse(array, node.lhs.node_index, breakpoints, type_left, false);
|
||||||
|
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
std::pair<bool, expression_type> breakpoint = is_breakpoint(current_type, array[node.lhs.node_index].op, is_first);
|
if(node.lhs.array->nshape()==1)
|
||||||
expression_type next_type = breakpoint.second;
|
type_left = VECTOR_AXPY_TYPE;
|
||||||
if(breakpoint.first)
|
|
||||||
breakpoints.push_back(std::make_pair(next_type, &node.lhs));
|
|
||||||
else
|
else
|
||||||
final_type = next_type;
|
type_left = MATRIX_AXPY_TYPE;
|
||||||
parse(array, node.lhs.node_index, next_type, breakpoints, final_type, false);
|
|
||||||
}
|
}
|
||||||
current_type = final_type;
|
|
||||||
|
//Right
|
||||||
|
expression_type type_right = INVALID_EXPRESSION_TYPE;
|
||||||
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
|
||||||
|
parse(array, node.rhs.node_index, breakpoints, type_right, false);
|
||||||
|
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
||||||
{
|
{
|
||||||
std::pair<bool, expression_type> breakpoint = is_breakpoint(current_type, array[node.rhs.node_index].op, is_first);
|
if(node.rhs.array->nshape()==1)
|
||||||
expression_type next_type = breakpoint.second;
|
type_right = VECTOR_AXPY_TYPE;
|
||||||
if(breakpoint.first)
|
|
||||||
breakpoints.push_back(std::make_pair(next_type, &node.rhs));
|
|
||||||
else
|
else
|
||||||
final_type = next_type;
|
type_right = MATRIX_AXPY_TYPE;
|
||||||
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
final_type = merge(array[idx].op, type_left, type_right);
|
||||||
|
std::pair<bool, bool> tmp = has_temporary(array[idx].op, type_left, type_right, is_first);
|
||||||
|
if(tmp.first)
|
||||||
|
breakpoints.push_back(std::make_pair(type_left, &array[idx].lhs));
|
||||||
|
if(tmp.second)
|
||||||
|
breakpoints.push_back(std::make_pair(type_right, &array[idx].rhs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,19 +180,19 @@ namespace isaac
|
|||||||
final_type = current_type;
|
final_type = current_type;
|
||||||
|
|
||||||
/*----Parse required temporaries-----*/
|
/*----Parse required temporaries-----*/
|
||||||
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
|
detail::parse(tree, rootidx, breakpoints, final_type);
|
||||||
std::vector<tools::shared_ptr<array> > temporaries_;
|
std::vector<tools::shared_ptr<array> > temporaries_;
|
||||||
|
|
||||||
/*----Compute required temporaries----*/
|
/*----Compute required temporaries----*/
|
||||||
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
|
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
|
||||||
{
|
{
|
||||||
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
|
tools::shared_ptr<model> const & pmodel = models[std::make_pair(it->first, dtype)];
|
||||||
array_expression::node const & node = tree[rit->second->node_index];
|
array_expression::node const & node = tree[it->second->node_index];
|
||||||
array_expression::node const & lmost = lhs_most(tree, node);
|
array_expression::node const & lmost = lhs_most(tree, node);
|
||||||
|
|
||||||
//Creates temporary
|
//Creates temporary
|
||||||
tools::shared_ptr<array> tmp;
|
tools::shared_ptr<array> tmp;
|
||||||
switch(rit->first){
|
switch(it->first){
|
||||||
case SCALAR_AXPY_TYPE:
|
case SCALAR_AXPY_TYPE:
|
||||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||||
|
|
||||||
@@ -226,15 +212,15 @@ namespace isaac
|
|||||||
|
|
||||||
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
||||||
fill(tree[rootidx].lhs, (array&)*tmp);
|
fill(tree[rootidx].lhs, (array&)*tmp);
|
||||||
tree[rootidx].rhs = *rit->second;
|
tree[rootidx].rhs = *it->second;
|
||||||
tree[rootidx].rhs.type_family = rit->second->type_family;
|
tree[rootidx].rhs.type_family = it->second->type_family;
|
||||||
|
|
||||||
//Execute
|
//Execute
|
||||||
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||||
tree[rootidx] = root_save;
|
tree[rootidx] = root_save;
|
||||||
|
|
||||||
//Incorporates the temporary within the array_expression
|
//Incorporates the temporary within the array_expression
|
||||||
fill(*rit->second, (array&)*tmp);
|
fill(*it->second, (array&)*tmp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -71,7 +71,7 @@ def main():
|
|||||||
#Includes
|
#Includes
|
||||||
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
||||||
#Sources
|
#Sources
|
||||||
src = 'src/lib/wrap/clBLAS.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||||
boostsrc = 'external/boost/libs/'
|
boostsrc = 'external/boost/libs/'
|
||||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||||
|
Reference in New Issue
Block a user