Backend: GEMM - Improved bounds checking

This commit is contained in:
Philippe Tillet
2015-07-02 14:02:31 -04:00
parent 41204d6b74
commit 4c123c4b38
8 changed files with 70 additions and 86 deletions

View File

@@ -151,12 +151,12 @@ namespace detail
if(name=="vaxpy") return VECTOR_AXPY_TYPE;
if(name=="dot") return REDUCTION_TYPE;
if(name=="maxpy") return MATRIX_AXPY_TYPE;
if(name=="gemvN") return ROW_WISE_REDUCTION_TYPE;
if(name=="gemvT") return COL_WISE_REDUCTION_TYPE;
if(name=="gemmNN") return MATRIX_PRODUCT_NN_TYPE;
if(name=="gemmNT") return MATRIX_PRODUCT_NT_TYPE;
if(name=="gemmTN") return MATRIX_PRODUCT_TN_TYPE;
if(name=="gemmTT") return MATRIX_PRODUCT_TT_TYPE;
if(name=="mreduction_rows") return ROW_WISE_REDUCTION_TYPE;
if(name=="mreduction_cols") return COL_WISE_REDUCTION_TYPE;
if(name=="mproduct_nn") return MATRIX_PRODUCT_NN_TYPE;
if(name=="mproduct_nt") return MATRIX_PRODUCT_NT_TYPE;
if(name=="mproduct_tn") return MATRIX_PRODUCT_TN_TYPE;
if(name=="mproduct_tt") return MATRIX_PRODUCT_TT_TYPE;
throw std::invalid_argument("Invalid expression: " + name);
}
@@ -176,17 +176,17 @@ namespace detail
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="maxpy")
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemvN")!=std::string::npos)
else if(template_name.find("mreduction_rows")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemvT")!=std::string::npos)
else if(template_name.find("mreduction_cols")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemmNN")!=std::string::npos)
else if(template_name.find("mproduct_nn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmTN")!=std::string::npos)
else if(template_name.find("mproduct_tn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmNT")!=std::string::npos)
else if(template_name.find("mproduct_nt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmTT")!=std::string::npos)
else if(template_name.find("mproduct_tt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
@@ -207,7 +207,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
document.Parse<0>(str.c_str());
//Deserialize
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemvN", "gemvT", "gemmNN", "gemmTN", "gemmNT", "gemmTT"};
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemv_n", "gemv_t", "mproduct_nn", "mproduct_tn", "mproduct_nt", "mproduct_tt"};
std::vector<std::string> dtype = {"float32", "float64"};
for(auto & operation : operations)
{
@@ -256,10 +256,10 @@ std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > ini
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
}
return res;
}