Code quality: renamed expression types

This commit is contained in:
Philippe Tillet
2015-12-19 00:20:27 -05:00
parent acd460402d
commit b6d596d26d
7 changed files with 83 additions and 86 deletions

View File

@@ -10,28 +10,28 @@ namespace isaac
enum expression_type enum expression_type
{ {
INVALID_EXPRESSION_TYPE, INVALID_EXPRESSION_TYPE,
AXPY_TYPE, ELEMENTWISE_1D,
GER_TYPE, ELEMENTWISE_2D,
DOT_TYPE, REDUCE_1D,
GEMV_N_TYPE, REDUCE_2D_ROWS,
GEMV_T_TYPE, REDUCE_2D_COLS,
GEMM_NN_TYPE, MATRIX_PRODUCT_NN,
GEMM_TN_TYPE, MATRIX_PRODUCT_TN,
GEMM_NT_TYPE, MATRIX_PRODUCT_NT,
GEMM_TT_TYPE MATRIX_PRODUCT_TT
}; };
inline expression_type expression_type_from_string(std::string const & name) inline expression_type expression_type_from_string(std::string const & name)
{ {
if(name=="elementwise_1d") return AXPY_TYPE; if(name=="elementwise_1d") return ELEMENTWISE_1D;
if(name=="reduce_1d") return DOT_TYPE; if(name=="reduce_1d") return REDUCE_1D;
if(name=="elementwise_2d") return GER_TYPE; if(name=="elementwise_2d") return ELEMENTWISE_2D;
if(name=="reduce_2d_rows") return GEMV_N_TYPE; if(name=="reduce_2d_rows") return REDUCE_2D_ROWS;
if(name=="reduce_2d_cols") return GEMV_T_TYPE; if(name=="reduce_2d_cols") return REDUCE_2D_COLS;
if(name=="matrix_product_nn") return GEMM_NN_TYPE; if(name=="matrix_product_nn") return MATRIX_PRODUCT_NN;
if(name=="matrix_product_nt") return GEMM_NT_TYPE; if(name=="matrix_product_nt") return MATRIX_PRODUCT_NT;
if(name=="matrix_product_tn") return GEMM_TN_TYPE; if(name=="matrix_product_tn") return MATRIX_PRODUCT_TN;
if(name=="matrix_product_tt") return GEMM_TT_TYPE; if(name=="matrix_product_tt") return MATRIX_PRODUCT_TT;
throw std::invalid_argument("Unrecognized expression: " + name); throw std::invalid_argument("Unrecognized expression: " + name);
} }

View File

@@ -76,7 +76,6 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
math_expression::container_type const & tree = expressions.tree(); math_expression::container_type const & tree = expressions.tree();
std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==OPERATOR_SFOR_TYPE;}, expressions, expressions.root(), true); std::vector<std::size_t> sfors = filter_nodes([](math_expression::node const & node){return node.op.type==OPERATOR_SFOR_TYPE;}, expressions, expressions.root(), true);
// std::cout << sfors.size() << std::endl;
for(unsigned int i = 0 ; i < sfors.size() ; ++i) for(unsigned int i = 0 ; i < sfors.size() ; ++i)
{ {
@@ -90,8 +89,6 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, math_expre
info[2] = evaluate(RHS_NODE_TYPE, {{"placeholder", "#name"}}, expressions, idx, mappings); info[2] = evaluate(RHS_NODE_TYPE, {{"placeholder", "#name"}}, expressions, idx, mappings);
info[0] = info[0].substr(1, info[0].size()-2); info[0] = info[0].substr(1, info[0].size()-2);
stream << "for(int " << info[0] << " ; " << info[1] << "; " << info[2] << ")" << std::endl; stream << "for(int " << info[0] << " ; " << info[1] << "; " << info[2] << ")" << std::endl;
// stream << "int sforidx0 = 0 ;" << std::endl;
} }
if(sfors.size()){ if(sfors.size()){

View File

@@ -664,10 +664,10 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
matrix_product::matrix_product(matrix_product_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, BIND_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds) matrix_product::matrix_product(matrix_product_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, BIND_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
{ {
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE; if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN;
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE; else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN;
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE; else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT;
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE; else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT;
else throw; else throw;
} }

View File

@@ -201,7 +201,7 @@ profiles::map_type& profiles::init(driver::CommandQueue const & queue)
map_type & result = cache_[queue]; map_type & result = cache_[queue];
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE}; numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE}; expression_type etypes[] = {ELEMENTWISE_1D, REDUCE_1D, ELEMENTWISE_2D, REDUCE_2D_ROWS, REDUCE_2D_COLS, MATRIX_PRODUCT_NN, MATRIX_PRODUCT_NT, MATRIX_PRODUCT_TN, MATRIX_PRODUCT_TT};
for(numeric_type dtype: dtypes) for(numeric_type dtype: dtypes)
for(expression_type etype: etypes) for(expression_type etype: etypes)
@@ -265,15 +265,15 @@ std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::ba
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE}; numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types) for(auto DTYPE : types)
{ {
res[std::make_pair(AXPY_TYPE, DTYPE)] = ptr_t (new templates::elementwise_1d(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(ELEMENTWISE_1D, DTYPE)] = ptr_t (new templates::elementwise_1d(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(DOT_TYPE, DTYPE)] = ptr_t(new templates::reduce_1d(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(REDUCE_1D, DTYPE)] = ptr_t(new templates::reduce_1d(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::elementwise_2d(1,128,1,16,32,templates::FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(ELEMENTWISE_2D, DTYPE)] = ptr_t(new templates::elementwise_2d(1,128,1,16,32,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::reduce_2d_rows(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(REDUCE_2D_ROWS, DTYPE)] = ptr_t(new templates::reduce_2d_rows(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::reduce_2d_cols(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(REDUCE_2D_COLS, DTYPE)] = ptr_t(new templates::reduce_2d_cols(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::matrix_product_nn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_NN, DTYPE)] = ptr_t(new templates::matrix_product_nn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::matrix_product_tn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_TN, DTYPE)] = ptr_t(new templates::matrix_product_tn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::matrix_product_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_NT, DTYPE)] = ptr_t(new templates::matrix_product_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::matrix_product_tt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_TT, DTYPE)] = ptr_t(new templates::matrix_product_tt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
} }
return res; return res;
} }

View File

@@ -18,13 +18,13 @@ namespace isaac
inline bool is_mmprod(expression_type x) inline bool is_mmprod(expression_type x)
{ {
return x==GEMM_NN_TYPE || x==GEMM_NT_TYPE || return x==MATRIX_PRODUCT_NN || x==MATRIX_PRODUCT_NT ||
x==GEMM_TN_TYPE || x==GEMM_TT_TYPE; x==MATRIX_PRODUCT_TN || x==MATRIX_PRODUCT_TT;
} }
inline bool is_mvprod(expression_type x) inline bool is_mvprod(expression_type x)
{ {
return x==GEMV_N_TYPE || x==GEMV_T_TYPE; return x==REDUCE_2D_ROWS || x==REDUCE_2D_COLS;
} }
inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first) inline bool has_temporary_impl(op_element op, expression_type expression, expression_type other, bool is_first)
@@ -35,27 +35,27 @@ namespace isaac
case OPERATOR_UNARY_TYPE_FAMILY: case OPERATOR_UNARY_TYPE_FAMILY:
case OPERATOR_BINARY_TYPE_FAMILY: case OPERATOR_BINARY_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| (result |= expression==GEMV_N_TYPE && other==GEMV_T_TYPE) || (result |= expression==REDUCE_2D_ROWS && other==REDUCE_2D_COLS)
|| (result |= expression==GEMV_T_TYPE && other==GEMV_N_TYPE); || (result |= expression==REDUCE_2D_COLS && other==REDUCE_2D_ROWS);
break; break;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY: case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
result |= is_mvprod(expression) result |= is_mvprod(expression)
|| expression==DOT_TYPE; || expression==REDUCE_1D;
break; break;
case OPERATOR_ROWS_DOT_TYPE_FAMILY: case OPERATOR_ROWS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==DOT_TYPE; || expression==REDUCE_1D;
break; break;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY: case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
result |= is_mmprod(expression) result |= is_mmprod(expression)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==DOT_TYPE; || expression==REDUCE_1D;
break; break;
case OPERATOR_GEMM_TYPE_FAMILY: case OPERATOR_GEMM_TYPE_FAMILY:
result |= (is_mmprod(expression) && !is_first) result |= (is_mmprod(expression) && !is_first)
|| is_mvprod(expression) || is_mvprod(expression)
|| expression==DOT_TYPE; || expression==REDUCE_1D;
break; break;
default: default:
break; break;
@@ -76,29 +76,29 @@ namespace isaac
{ {
case OPERATOR_UNARY_TYPE_FAMILY: case OPERATOR_UNARY_TYPE_FAMILY:
if(is_mmprod(left)) if(is_mmprod(left))
return GER_TYPE; return ELEMENTWISE_2D;
return left; return left;
case OPERATOR_BINARY_TYPE_FAMILY: case OPERATOR_BINARY_TYPE_FAMILY:
if(left == GEMV_N_TYPE || right == GEMV_N_TYPE) return GEMV_N_TYPE; if(left == REDUCE_2D_ROWS || right == REDUCE_2D_ROWS) return REDUCE_2D_ROWS;
else if(left == GEMV_T_TYPE || right == GEMV_T_TYPE) return GEMV_T_TYPE; else if(left == REDUCE_2D_COLS || right == REDUCE_2D_COLS) return REDUCE_2D_COLS;
else if(left == DOT_TYPE || right == DOT_TYPE) return DOT_TYPE; else if(left == REDUCE_1D || right == REDUCE_1D) return REDUCE_1D;
else if(left == GER_TYPE || right == GER_TYPE) return GER_TYPE; else if(left == ELEMENTWISE_2D || right == ELEMENTWISE_2D) return ELEMENTWISE_2D;
else if(left == AXPY_TYPE || right == AXPY_TYPE) return op.type==OPERATOR_OUTER_PROD_TYPE?GER_TYPE:AXPY_TYPE; else if(left == ELEMENTWISE_1D || right == ELEMENTWISE_1D) return op.type==OPERATOR_OUTER_PROD_TYPE?ELEMENTWISE_2D:ELEMENTWISE_1D;
else if(is_mmprod(left) || is_mmprod(right)) return GER_TYPE; else if(is_mmprod(left) || is_mmprod(right)) return ELEMENTWISE_2D;
else if(right == INVALID_EXPRESSION_TYPE) return left; else if(right == INVALID_EXPRESSION_TYPE) return left;
else if(left == INVALID_EXPRESSION_TYPE) return right; else if(left == INVALID_EXPRESSION_TYPE) return right;
throw; throw;
case OPERATOR_VECTOR_DOT_TYPE_FAMILY: case OPERATOR_VECTOR_DOT_TYPE_FAMILY:
return DOT_TYPE; return REDUCE_1D;
case OPERATOR_ROWS_DOT_TYPE_FAMILY: case OPERATOR_ROWS_DOT_TYPE_FAMILY:
return GEMV_N_TYPE; return REDUCE_2D_ROWS;
case OPERATOR_COLUMNS_DOT_TYPE_FAMILY: case OPERATOR_COLUMNS_DOT_TYPE_FAMILY:
return GEMV_T_TYPE; return REDUCE_2D_COLS;
case OPERATOR_GEMM_TYPE_FAMILY: case OPERATOR_GEMM_TYPE_FAMILY:
if(op.type==OPERATOR_GEMM_NN_TYPE) return GEMM_NN_TYPE; if(op.type==OPERATOR_GEMM_NN_TYPE) return MATRIX_PRODUCT_NN;
else if(op.type==OPERATOR_GEMM_TN_TYPE) return GEMM_TN_TYPE; else if(op.type==OPERATOR_GEMM_TN_TYPE) return MATRIX_PRODUCT_TN;
else if(op.type==OPERATOR_GEMM_NT_TYPE) return GEMM_NT_TYPE; else if(op.type==OPERATOR_GEMM_NT_TYPE) return MATRIX_PRODUCT_NT;
else return GEMM_TT_TYPE; else return MATRIX_PRODUCT_TT;
default: default:
throw; throw;
} }
@@ -120,9 +120,9 @@ namespace isaac
else if(node.lhs.subtype == DENSE_ARRAY_TYPE) else if(node.lhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1) if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.lhs.array->shape())<=1)
type_left = AXPY_TYPE; type_left = ELEMENTWISE_1D;
else else
type_left = GER_TYPE; type_left = ELEMENTWISE_2D;
} }
//Right //Right
@@ -132,9 +132,9 @@ namespace isaac
else if(node.rhs.subtype == DENSE_ARRAY_TYPE) else if(node.rhs.subtype == DENSE_ARRAY_TYPE)
{ {
if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1) if(node.op.type==OPERATOR_MATRIX_ROW_TYPE || node.op.type==OPERATOR_MATRIX_COLUMN_TYPE || ng1(node.rhs.array->shape())<=1)
type_right = AXPY_TYPE; type_right = ELEMENTWISE_1D;
else else
type_right = GER_TYPE; type_right = ELEMENTWISE_2D;
} }
final_type = merge(array[idx].op, type_left, type_right); final_type = merge(array[idx].op, type_left, type_right);
@@ -174,9 +174,9 @@ namespace isaac
expression_type current_type; expression_type current_type;
auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;}; auto ng1 = [](shape_t const & shape){ size_t res = 0 ; for(size_t i = 0 ; i < shape.size() ; ++i) res += (shape[i] > 1); return res;};
if(ng1(expression.shape())<=1) if(ng1(expression.shape())<=1)
current_type=AXPY_TYPE; current_type=ELEMENTWISE_1D;
else else
current_type=GER_TYPE; current_type=ELEMENTWISE_2D;
final_type = current_type; final_type = current_type;
/*----Parse required temporaries-----*/ /*----Parse required temporaries-----*/
@@ -192,17 +192,17 @@ namespace isaac
//Creates temporary //Creates temporary
std::shared_ptr<array> tmp; std::shared_ptr<array> tmp;
switch(it->first){ switch(it->first){
case DOT_TYPE: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break; case REDUCE_1D: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
case AXPY_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break; case ELEMENTWISE_1D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break; case REDUCE_2D_ROWS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break; case REDUCE_2D_COLS: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case GER_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break; case ELEMENTWISE_2D: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break; case MATRIX_PRODUCT_NN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break; case MATRIX_PRODUCT_NT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break; case MATRIX_PRODUCT_TN: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break; case MATRIX_PRODUCT_TT: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation"); default: throw std::invalid_argument("Unrecognized operation");
} }

View File

@@ -18,10 +18,10 @@ void matrix_product::handle_node(math_expression::container_type const & tree, s
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs; if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
switch(tree[rootidx].op.type) switch(tree[rootidx].op.type)
{ {
case OPERATOR_GEMM_NN_TYPE: a.type = GEMM_NN_TYPE; break; case OPERATOR_GEMM_NN_TYPE: a.type = MATRIX_PRODUCT_NN; break;
case OPERATOR_GEMM_NT_TYPE: a.type = GEMM_NT_TYPE; break; case OPERATOR_GEMM_NT_TYPE: a.type = MATRIX_PRODUCT_NT; break;
case OPERATOR_GEMM_TN_TYPE: a.type = GEMM_TN_TYPE; break; case OPERATOR_GEMM_TN_TYPE: a.type = MATRIX_PRODUCT_TN; break;
case OPERATOR_GEMM_TT_TYPE: a.type = GEMM_TT_TYPE; break; case OPERATOR_GEMM_TT_TYPE: a.type = MATRIX_PRODUCT_TT; break;
default: break; default: break;
} }
} }

View File

@@ -70,15 +70,15 @@ namespace tools
else else
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))(); name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
if(name=="elementwise_1d") return sc::AXPY_TYPE; if(name=="elementwise_1d") return sc::ELEMENTWISE_1D;
else if(name=="elementwise_2d") return sc::GER_TYPE; else if(name=="elementwise_2d") return sc::ELEMENTWISE_2D;
else if(name=="reduce_1d") return sc::DOT_TYPE; else if(name=="reduce_1d") return sc::REDUCE_1D;
else if(name=="reduce_2d_rows") return sc::GEMV_N_TYPE; else if(name=="reduce_2d_rows") return sc::REDUCE_2D_ROWS;
else if(name=="reduce_2d_cols") return sc::GEMV_T_TYPE; else if(name=="reduce_2d_cols") return sc::REDUCE_2D_COLS;
else if(name=="matrix_product_nn") return sc::GEMM_NN_TYPE; else if(name=="matrix_product_nn") return sc::MATRIX_PRODUCT_NN;
else if(name=="matrix_product_tn") return sc::GEMM_TN_TYPE; else if(name=="matrix_product_tn") return sc::MATRIX_PRODUCT_TN;
else if(name=="matrix_product_nt") return sc::GEMM_NT_TYPE; else if(name=="matrix_product_nt") return sc::MATRIX_PRODUCT_NT;
else if(name=="matrix_product_tt") return sc::GEMM_TT_TYPE; else if(name=="matrix_product_tt") return sc::MATRIX_PRODUCT_TT;
else else
{ {
PyErr_SetString(PyExc_TypeError, "Template type not understood"); PyErr_SetString(PyExc_TypeError, "Template type not understood");