More efficient access pattern in the GEMV kernel
This commit is contained in:
@@ -171,9 +171,9 @@ namespace detail
|
||||
else if(template_name=="maxpy")
|
||||
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemvN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemvT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemmNN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmTN")!=std::string::npos)
|
||||
@@ -247,8 +247,8 @@ model_map_t init_models(cl::CommandQueue & queue)
|
||||
res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new model(vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new model(reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new model(maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_rows(1, 8, 8, 16, FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_cols(1, 8, 16, 128, FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new model(mproduct_nn(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
|
||||
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new model(mproduct_tn(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
|
||||
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new model(mproduct_nt(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
|
||||
|
Reference in New Issue
Block a user