Files
triton/lib/symbolic/execute.cpp

235 lines
10 KiB
C++

#include <assert.h>
#include <list>
#include <vector>
#include <stdexcept>
#include "isaac/types.h"
#include "isaac/array.h"
#include "isaac/profiles/profiles.h"
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac
{
namespace detail
{
typedef std::vector<std::pair<expression_type, lhs_rhs_element*> > breakpoints_t;
inline bool is_mmprod(expression_type x)
{
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
}
inline bool is_mvprod(expression_type x)
{
return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
}
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
{
bool result = false;
switch(op.type_family)
{
case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY:
result |= is_mmprod(expression)
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
break;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression)
|| expression==DOT_TYPE;
break;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
break;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
break;
case OPERATOR_GEMM_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
break;
default:
break;
}
return result;
}
inline std::pair<bool, bool> has_temporary(op_element op, expression_type left, expression_type right, bool is_first)
{
bool has_temporary_left = has_temporary_impl(op, left, right, is_first);
bool has_temporary_right = has_temporary_impl(op, right, left, is_first);
return std::make_pair(has_temporary_left, has_temporary_right);
}
inline expression_type merge(op_element op, expression_type left, expression_type right)
{
switch(op.type_family)
{
case OPERATOR_UNARY_TYPE_FAMILY:
if(is_mmprod(left))
return GER_TYPE;
return left;
case OPERATOR_BINARY_TYPE_FAMILY:
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
else if(right == INVALID_EXPRESSION_TYPE) return left;
else if(left == INVALID_EXPRESSION_TYPE) return right;
throw;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
return DOT_TYPE;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
return GEMV_N_TYPE;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
return GEMV_T_TYPE;
case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
else return GEMM_TT_TYPE;
default:
throw;
}
}
/** @brief Parses the breakpoints for a given expression tree */
static void parse(math_expression::container_type&array, size_t idx,
breakpoints_t & breakpoints,
expression_type & final_type,
bool is_first = true)
{
math_expression::node & node = array[idx];
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
//Left
expression_type type_left = INVALID_EXPRESSION_TYPE;
if (node.lhs.type_family == COMPOSITE_OPERATOR_FAMILY)
parse(array, node.lhs.node_index, breakpoints, type_left, false);
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
type_left = AXPY_TYPE;
else
type_left = GER_TYPE;
}
//Right
expression_type type_right = INVALID_EXPRESSION_TYPE;
if (node.rhs.type_family == COMPOSITE_OPERATOR_FAMILY)
parse(array, node.rhs.node_index, breakpoints, type_right, false);
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
type_right = AXPY_TYPE;
else
type_right = GER_TYPE;
}
final_type = merge(array[idx].op, type_left, type_right);
std::pair<bool, bool> tmp = has_temporary(array[idx].op, type_left, type_right, is_first);
if(tmp.first)
breakpoints.push_back(std::make_pair(type_left, &array[idx].lhs));
if(tmp.second)
breakpoints.push_back(std::make_pair(type_right, &array[idx].rhs));
}
}
/** @brief Executes a math_expression on the given models map*/
void execute(execution_handler const & c, profiles::map_type & profiles)
{
math_expression expression = c.x();
driver::Context const & context = expression.context();
size_t rootidx = expression.root();
math_expression::container_type & tree = const_cast<math_expression::container_type &>(expression.tree());
math_expression::node root_save = tree[rootidx];
//Todo: technically the datatype should be per temporary
numeric_type dtype = expression.dtype();
std::vector<std::shared_ptr<array> > temporaries_;
expression_type final_type;
//GEMM
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
final_type = args.type;
}
//Default
else
{
detail::breakpoints_t breakpoints;
breakpoints.reserve(8);
//Init
expression_type current_type;
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
if(ng1(expression.shape())<=1)
current_type=AXPY_TYPE;
else
current_type=GER_TYPE;
final_type = current_type;
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, breakpoints, final_type);
/*----Compute required temporaries----*/
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
{
std::shared_ptr<profiles::value_type> const & profile = profiles[std::make_pair(it->first, dtype)];
math_expression::node const & node = tree[it->second->node_index];
math_expression::node const & lmost = lhs_most(tree, node);
//Creates temporary
std::shared_ptr<array> tmp;
switch(it->first){
case DOT_TYPE: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
case AXPY_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case GER_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation");
}
temporaries_.push_back(tmp);
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
fill(tree[rootidx].lhs, (array&)*tmp);
tree[rootidx].rhs = *it->second;
tree[rootidx].rhs.type_family = it->second->type_family;
//Execute
profile->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
tree[rootidx] = root_save;
//Incorporates the temporary within, the math_expression
fill(*it->second, (array&)*tmp);
}
}
/*-----Compute final expression-----*/
profiles[std::make_pair(final_type, dtype)]->execute(execution_handler(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
}
void execute(execution_handler const & c)
{
execute(c, isaac::profiles::get(c.execution_options().queue(c.x().context())));
}
}