Cleaning: Largely renamed templates to BLAS-like names
This commit is contained in:
@@ -19,13 +19,13 @@ namespace isaac
|
||||
|
||||
inline bool is_mmprod(expression_type x)
|
||||
{
|
||||
return x==MATRIX_PRODUCT_NN_TYPE || x==MATRIX_PRODUCT_NT_TYPE ||
|
||||
x==MATRIX_PRODUCT_TN_TYPE || x==MATRIX_PRODUCT_TT_TYPE;
|
||||
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
|
||||
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
|
||||
}
|
||||
|
||||
inline bool is_mvprod(expression_type x)
|
||||
{
|
||||
return x==ROW_WISE_REDUCTION_TYPE || x==COL_WISE_REDUCTION_TYPE;
|
||||
return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
|
||||
}
|
||||
|
||||
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
|
||||
@@ -36,27 +36,27 @@ namespace isaac
|
||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||
result |= is_mmprod(expression)
|
||||
|| (result |= expression==ROW_WISE_REDUCTION_TYPE && other==COL_WISE_REDUCTION_TYPE)
|
||||
|| (result |= expression==COL_WISE_REDUCTION_TYPE && other==ROW_WISE_REDUCTION_TYPE);
|
||||
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|
||||
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
|
||||
break;
|
||||
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
||||
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
||||
result |= is_mvprod(expression)
|
||||
|| expression==REDUCTION_TYPE;
|
||||
|| expression==DOT_TYPE;
|
||||
break;
|
||||
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
||||
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
||||
result |= is_mmprod(expression)
|
||||
|| is_mvprod(expression)
|
||||
|| expression==REDUCTION_TYPE;
|
||||
|| expression==DOT_TYPE;
|
||||
break;
|
||||
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
||||
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
||||
result |= is_mmprod(expression)
|
||||
|| is_mvprod(expression)
|
||||
|| expression==REDUCTION_TYPE;
|
||||
|| expression==DOT_TYPE;
|
||||
break;
|
||||
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
||||
case OPERATOR_GEMM_TYPE_FAMILY:
|
||||
result |= (is_mmprod(expression) && !is_first)
|
||||
|| is_mvprod(expression)
|
||||
|| expression==REDUCTION_TYPE;
|
||||
|| expression==DOT_TYPE;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -77,28 +77,28 @@ namespace isaac
|
||||
{
|
||||
case OPERATOR_UNARY_TYPE_FAMILY:
|
||||
if(is_mmprod(left))
|
||||
return MATRIX_AXPY_TYPE;
|
||||
return GER_TYPE;
|
||||
return left;
|
||||
case OPERATOR_BINARY_TYPE_FAMILY:
|
||||
if(left == ROW_WISE_REDUCTION_TYPE || right == ROW_WISE_REDUCTION_TYPE) return ROW_WISE_REDUCTION_TYPE;
|
||||
else if(left == COL_WISE_REDUCTION_TYPE || right == COL_WISE_REDUCTION_TYPE) return COL_WISE_REDUCTION_TYPE;
|
||||
else if(left == REDUCTION_TYPE || right == REDUCTION_TYPE) return REDUCTION_TYPE;
|
||||
else if(left == VECTOR_AXPY_TYPE || right == VECTOR_AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?MATRIX_AXPY_TYPE:VECTOR_AXPY_TYPE;
|
||||
else if(left == MATRIX_AXPY_TYPE || right == MATRIX_AXPY_TYPE) return MATRIX_AXPY_TYPE;
|
||||
else if(is_mmprod(left) || is_mmprod(right)) return MATRIX_AXPY_TYPE;
|
||||
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
|
||||
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
|
||||
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
|
||||
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
|
||||
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
|
||||
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
|
||||
std::cout << left << " " << right << std::endl;
|
||||
throw;
|
||||
case OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY:
|
||||
return REDUCTION_TYPE;
|
||||
case OPERATOR_ROWS_REDUCTION_TYPE_FAMILY:
|
||||
return ROW_WISE_REDUCTION_TYPE;
|
||||
case OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY:
|
||||
return COL_WISE_REDUCTION_TYPE;
|
||||
case OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY:
|
||||
if(op.type==OPERATOR_MATRIX_PRODUCT_NN_TYPE) return MATRIX_PRODUCT_NN_TYPE;
|
||||
else if(op.type==OPERATOR_MATRIX_PRODUCT_TN_TYPE) return MATRIX_PRODUCT_TN_TYPE;
|
||||
else if(op.type==OPERATOR_MATRIX_PRODUCT_NT_TYPE) return MATRIX_PRODUCT_NT_TYPE;
|
||||
else return MATRIX_PRODUCT_TT_TYPE;
|
||||
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
|
||||
return DOT_TYPE;
|
||||
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
|
||||
return GEMV_N_TYPE;
|
||||
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
|
||||
return GEMV_T_TYPE;
|
||||
case OPERATOR_GEMM_TYPE_FAMILY:
|
||||
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
|
||||
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
|
||||
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
|
||||
else return GEMM_TT_TYPE;
|
||||
default:
|
||||
throw;
|
||||
}
|
||||
@@ -119,9 +119,9 @@ namespace isaac
|
||||
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
|
||||
{
|
||||
if(node.lhs.array->nshape()==1)
|
||||
type_left = VECTOR_AXPY_TYPE;
|
||||
type_left = AXPY_TYPE;
|
||||
else
|
||||
type_left = MATRIX_AXPY_TYPE;
|
||||
type_left = GER_TYPE;
|
||||
}
|
||||
|
||||
//Right
|
||||
@@ -131,9 +131,9 @@ namespace isaac
|
||||
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
|
||||
{
|
||||
if(node.rhs.array->nshape()==1)
|
||||
type_right = VECTOR_AXPY_TYPE;
|
||||
type_right = AXPY_TYPE;
|
||||
else
|
||||
type_right = MATRIX_AXPY_TYPE;
|
||||
type_right = GER_TYPE;
|
||||
}
|
||||
|
||||
|
||||
@@ -171,12 +171,10 @@ namespace isaac
|
||||
|
||||
//Init
|
||||
expression_type current_type;
|
||||
if(root_save.lhs.array->nshape()==0)
|
||||
current_type = SCALAR_AXPY_TYPE;
|
||||
else if(root_save.lhs.array->nshape()==1)
|
||||
current_type=VECTOR_AXPY_TYPE;
|
||||
if(root_save.lhs.array->nshape()<=1)
|
||||
current_type=AXPY_TYPE;
|
||||
else
|
||||
current_type=MATRIX_AXPY_TYPE;
|
||||
current_type=GER_TYPE;
|
||||
final_type = current_type;
|
||||
|
||||
/*----Parse required temporaries-----*/
|
||||
@@ -193,18 +191,17 @@ namespace isaac
|
||||
//Creates temporary
|
||||
tools::shared_ptr<array> tmp;
|
||||
switch(it->first){
|
||||
case SCALAR_AXPY_TYPE:
|
||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
|
||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
|
||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
|
||||
default: throw std::invalid_argument("Unrecognized operation");
|
||||
}
|
||||
|
Reference in New Issue
Block a user