Code quality: renamed expression types

This commit is contained in:
Philippe Tillet
2015-12-19 00:20:27 -05:00
parent acd460402d
commit b6d596d26d
7 changed files with 83 additions and 86 deletions

View File

@@ -18,13 +18,13 @@ namespace isaac
inline bool is_mmprod(expression_type x)
{
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE ||
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE;
return x==MATRIX_PRODUCT_NN || x==MATRIX_PRODUCT_NT ||
x==MATRIX_PRODUCT_TN || x==MATRIX_PRODUCT_TT;
}
inline bool is_mvprod(expression_type x)
{
return x==GEMV_N_TYPE || x==GEMV_T_TYPE;
return x==REDUCE_2D_ROWS || x==REDUCE_2D_COLS;
}
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
@@ -35,27 +35,27 @@ namespace isaac
case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY:
result |= is_mmprod(expression)
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE)
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE);
|| (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|| (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
break;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression)
|| expression==DOT_TYPE;
|| expression==REDUCE_1D;
break;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
|| expression==REDUCE_1D;
break;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
|| expression==REDUCE_1D;
break;
case OPERATOR_GEMM_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression)
|| expression==DOT_TYPE;
|| expression==REDUCE_1D;
break;
default:
break;
@@ -76,29 +76,29 @@ namespace isaac
{
case OPERATOR_UNARY_TYPE_FAMILY:
if(is_mmprod(left))
return GER_TYPE;
return ELEMENTWISE_2D;
return left;
case OPERATOR_BINARY_TYPE_FAMILY:
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE;
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE;
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE;
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE;
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE;
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE;
if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OPERATOR_OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
else if(right == INVALID_EXPRESSION_TYPE) return left;
else if(left == INVALID_EXPRESSION_TYPE) return right;
throw;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
return DOT_TYPE;
return REDUCE_1D;
case OPERATOR_ROWS_DOT_TYPE_FAMILY:
return GEMV_N_TYPE;
return REDUCE_2D_ROWS;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
return GEMV_T_TYPE;
return REDUCE_2D_COLS;
case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE;
else return GEMM_TT_TYPE;
if(op.type==OPERATOR_GEMM_NN_TYPE) return MATRIX_PRODUCT_NN;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return MATRIX_PRODUCT_TN;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return MATRIX_PRODUCT_NT;
else return MATRIX_PRODUCT_TT;
default:
throw;
}
@@ -120,9 +120,9 @@ namespace isaac
else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
type_left = AXPY_TYPE;
type_left = ELEMENTWISE_1D;
else
type_left = GER_TYPE;
type_left = ELEMENTWISE_2D;
}
//Right
@@ -132,9 +132,9 @@ namespace isaac
else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
type_right = AXPY_TYPE;
type_right = ELEMENTWISE_1D;
else
type_right = GER_TYPE;
type_right = ELEMENTWISE_2D;
}
final_type = merge(array[idx].op, type_left, type_right);
@@ -174,9 +174,9 @@ namespace isaac
expression_type current_type;
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
if(ng1(expression.shape())<=1)
current_type=AXPY_TYPE;
current_type=ELEMENTWISE_1D;
else
current_type=GER_TYPE;
current_type=ELEMENTWISE_2D;
final_type = current_type;
/*----Parse required temporaries-----*/
@@ -192,17 +192,17 @@ namespace isaac
//Creates temporary
std::shared_ptr<array> tmp;
switch(it->first){
case DOT_TYPE: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
case REDUCE_1D: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
case AXPY_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case ELEMENTWISE_1D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case REDUCE_2D_ROWS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case REDUCE_2D_COLS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case GER_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
case ELEMENTWISE_2D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_NT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case MATRIX_PRODUCT_TN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case MATRIX_PRODUCT_TT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation");
}