Various fixes
This commit is contained in:
@@ -142,14 +142,14 @@ namespace detail
|
||||
static expression_type get_expression_type(std::string const & name)
|
||||
{
|
||||
if(name=="vaxpy") return VECTOR_AXPY_TYPE;
|
||||
if(name=="reduction") return REDUCTION_TYPE;
|
||||
if(name=="dot") return REDUCTION_TYPE;
|
||||
if(name=="maxpy") return MATRIX_AXPY_TYPE;
|
||||
if(name=="row-wise-reductionN") return ROW_WISE_REDUCTION_TYPE;
|
||||
if(name=="row-wise-reductionT") return COL_WISE_REDUCTION_TYPE;
|
||||
if(name=="matrix-productNN") return MATRIX_PRODUCT_NN_TYPE;
|
||||
if(name=="matrix-productNT") return MATRIX_PRODUCT_NT_TYPE;
|
||||
if(name=="matrix-productTN") return MATRIX_PRODUCT_TN_TYPE;
|
||||
if(name=="matrix-productTT") return MATRIX_PRODUCT_TT_TYPE;
|
||||
if(name=="gemvN") return ROW_WISE_REDUCTION_TYPE;
|
||||
if(name=="gemvT") return COL_WISE_REDUCTION_TYPE;
|
||||
if(name=="gemmNN") return MATRIX_PRODUCT_NN_TYPE;
|
||||
if(name=="gemmNT") return MATRIX_PRODUCT_NT_TYPE;
|
||||
if(name=="gemmTN") return MATRIX_PRODUCT_TN_TYPE;
|
||||
if(name=="gemmTT") return MATRIX_PRODUCT_TT_TYPE;
|
||||
throw ;
|
||||
}
|
||||
|
||||
@@ -164,22 +164,23 @@ namespace detail
|
||||
{
|
||||
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
|
||||
if(template_name=="vaxpy")
|
||||
return tools::shared_ptr<base>(new vaxpy( vaxpy_parameters(a[0], a[1], a[2], fetch[a[3]])));
|
||||
else if(template_name=="reduction")
|
||||
return tools::shared_ptr<base>(new reduction( reduction_parameters(a[0], a[1], a[2], fetch[a[3]])));
|
||||
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="dot")
|
||||
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="maxpy")
|
||||
return tools::shared_ptr<base>(new maxpy( maxpy_parameters(a[0], a[1], a[2], a[3], a[4], fetch[a[5]])));
|
||||
else if(template_name.find("row-wise-reduction")!=std::string::npos)
|
||||
{
|
||||
return tools::shared_ptr<base>(new mreduction_rows( mreduction_parameters(a[0], a[1], a[2], a[3], fetch[a[4]])));
|
||||
}
|
||||
else if(template_name.find("matrix-product")!=std::string::npos)
|
||||
{
|
||||
char A_trans = template_name[15];
|
||||
char B_trans = template_name[16];
|
||||
return tools::shared_ptr<base>(new mproduct( mproduct_parameters(a[0], a[1], a[2], a[3], a[4], a[5], a[6],
|
||||
fetch[a[7]], fetch[a[8]], a[9], a[10]), A_trans, B_trans));
|
||||
}
|
||||
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemvN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
else if(template_name.find("gemvT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
else if(template_name.find("gemmNN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmTN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmNT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmTT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else
|
||||
throw operation_not_supported_exception("Cannot create the given operation");
|
||||
}
|
||||
@@ -198,9 +199,9 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
|
||||
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
||||
document.Parse<0>(str.c_str());
|
||||
//Deserialize
|
||||
std::vector<std::string> operations = tools::make_vector<std::string>() << "vaxpy" << "reduction"
|
||||
<< "maxpy" << "row-wise-reductionN" << "row-wise-reductionT"
|
||||
<< "matrix-productNN" << "matrix-productTN" << "matrix-productNT" << "matrix-productTT";
|
||||
std::vector<std::string> operations = tools::make_vector<std::string>() << "vaxpy" << "dot"
|
||||
<< "maxpy" << "gemvN" << "gemvT"
|
||||
<< "gemmNN" << "gemmTN" << "gemmTT";
|
||||
std::vector<std::string> dtype = tools::make_vector<std::string>() << "float32" << "float64";
|
||||
for(std::vector<std::string>::iterator op = operations.begin() ; op != operations.end() ; ++op)
|
||||
{
|
||||
|
Reference in New Issue
Block a user