Backend: GEMM - Improved bounds checking

This commit is contained in:
Philippe Tillet
2015-07-02 14:02:31 -04:00
parent 41204d6b74
commit 4c123c4b38
8 changed files with 70 additions and 86 deletions

View File

@@ -11,12 +11,12 @@ if(CUDA_FOUND)
endif()
#CLAMDBLAS
#find_package(CLAMDBLAS)
#if(CLAMDBLAS_FOUND)
find_package(CLAMDBLAS)
if(CLAMDBLAS_FOUND)
set(BLAS_DEF ${BLAS_DEF} "-DBENCH_CLBLAS")
#include_directories(${CLAMDBLAS_INCLUDE_DIR})
#set(BLAS_LIBS ${BLAS_LIBS} ${CLAMDBLAS_LIBRARIES} )
#endif()
include_directories(${CLAMDBLAS_INCLUDE_DIR})
set(BLAS_LIBS ${BLAS_LIBS} ${CLAMDBLAS_LIBRARIES} )
endif()
#CBLAS
find_package(MKL)
@@ -25,12 +25,12 @@ if(MKL_FOUND)
include_directories(${MKL_INCLUDE_DIR})
set(BLAS_LIBS ${BLAS_LIBS} ${MKL_LIBRARIES} )
else()
find_package(OpenBlas)
if(OPENBLAS_FOUND)
set(BLAS_DEF ${BLAS_DEF} "-DBENCH_CBLAS")
include_directories(${OPENBLAS_INCLUDE_DIR})
set(BLAS_LIBS ${BLAS_LIBS} ${OPENBLAS_LIBRARIES} )
endif()
# find_package(OpenBlas)
# if(OPENBLAS_FOUND)
# set(BLAS_DEF ${BLAS_DEF} "-DBENCH_CBLAS")
# include_directories(${OPENBLAS_INCLUDE_DIR})
# set(BLAS_LIBS ${BLAS_LIBS} ${OPENBLAS_LIBRARIES} )
# endif()
endif()
string(REPLACE ";" " " BLAS_DEF_STR "${BLAS_DEF}")

View File

@@ -3,7 +3,7 @@ file(GLOB CLAMDBLAS_ROOT /opt/clBlas*)
set(CLAMDBLAS_INCLUDE_HINTS "${CLAMDBLAS_ROOT}/include")
set(CLAMDBLAS_LIBRARIES_HINTS "${CLAMDBLAS_ROOT}/lib64")
find_path(CLAMDBLAS_INCLUDE_DIR clAmdBlas.h HINTS ${CLAMDBLAS_INCLUDE_HINTS})
find_path(CLAMDBLAS_INCLUDE_DIR clBLAS.h HINTS ${CLAMDBLAS_INCLUDE_HINTS})
find_library(CLAMDBLAS_LIBRARIES NAMES clBLAS HINTS ${CLAMDBLAS_LIBRARIES_HINTS})
if(CLAMDBLAS_LIBRARIES)

View File

@@ -380,10 +380,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
stream << "//Inner loop" << std::endl;
if (check_bounds_)
stream << "for(unsigned int k = 0; k < " << p_.kL << " && (block_k + k < chunk_size); k+=" << p_.kS << "){" << std::endl;
else
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
stream << "for(unsigned int k = 0; k < " << p_.kL << " && (block_k + k < chunk_size); k+=" << p_.kS << "){" << std::endl;
stream.inc_tab();
stream << "//Fetch A to registers" << std::endl;
@@ -504,7 +501,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
stream << "//Write back C" << std::endl;
unsigned int ministartstride0 = p_.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p_.mS:p_.simd_width;
unsigned int ministartstride1 = p_.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p_.nS:p_.simd_width;
stream << "C += (gidx*" << p_.mL << " + idx*" << ministartstride0 << ")" << CSTRIDE1 << " + (gidy*" << p_.nL << " + idy*" << ministartstride1 << ")*Cld + gidz*Cld*N;" << std::endl;
stream << "size_t offx = (gidx*" << p_.mL << " + idx*" << ministartstride0 << ")" << ";" << std::endl;
stream << "size_t offy = (gidy*" << p_.nL << " + idy*" << ministartstride1 << ");" << std::endl;
stream << "C += " << "offx" << CSTRIDE1 << " + offy*Cld + gidz*Cld*N;" << std::endl;
for(int_t m=0; m < p_.mS; ++m)
for(int_t n=0; n < p_.nS; ++n)
{
@@ -513,8 +512,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
string Ci = to_string((m/p_.simd_width)*(ministride0*p_.simd_width) + m%p_.simd_width);
string Cj = to_string((n/p_.simd_width)*(ministride1*p_.simd_width) + n%p_.simd_width);
if (check_bounds_)
stream << "if (in_bounds_m[" << m << "] && in_bounds_n[" << n << "]) " ;
stream << "if((offx + " << Ci << ")<M && (" << Cj << " + offy)<N)"<< std::flush;
stream << "C[" << Ci << CSTRIDE1 << " + " << Cj << "*Cld] = rC[" << m << "][" << n << "]*alpha + ((beta==0)?0:beta*C[" << Ci << " + " << Cj << "*Cld]);" << std::endl;
}
stream.dec_tab();
@@ -589,7 +587,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
using tools::align;
driver::NDRange global = (strcmp(suffix,"fallback")==0)?driver::NDRange(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth):driver::NDRange(M/p_.mS, N/p_.nS, p_.depth);
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth);
unsigned int current_arg = 0;
set_arguments_functor helper(binder, current_arg, gemm);
@@ -612,20 +610,18 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
gemm.setSizeArg(current_arg++, B.start()[0] + B.start()[1]*B.ld()/p_.simd_width);
gemm.setSizeArg(current_arg++, B.stride()[0]);
// std::cout << "before " << *out << std::endl;
helper.set_arguments(beta.dtype(), beta.values());
options.enqueue(program.context(), gemm, global, local);
options.queue(program.context()).synchronize();
// std::cout << "after " << *out << std::endl;
if(p_.depth > 1)
{
unsigned int current_arg = 0;
driver::Kernel reduce(program, reduce_name);
driver::NDRange local(p_.local_size_0, p_.local_size_1);
driver::NDRange global = driver::NDRange(M, N);
driver::NDRange global(align(M, p_.local_size_0), align(N, p_.local_size_1));
set_arguments_functor helper(binder, current_arg, reduce);
reduce.setSizeArg(current_arg++, M);
reduce.setSizeArg(current_arg++, N);
@@ -725,7 +721,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
execution_options_type const & options = ctr.execution_options();
if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1
if (ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1
|| (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0 || pA->ld()%p_.simd_width > 0 || pB->ld()%p_.simd_width > 0)))
{
fallback.enqueue_block(queue, M, N, K, create_slice(*pA, 0, M, 0, K, swap_A), create_slice(*pB, 0, K, 0, N, swap_B),
@@ -733,22 +729,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return;
}
int_t lM = M / p_.mL * p_.mL;
int_t lN = N / p_.nL * p_.nL;
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
value_scalar _1(1, dtype);
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);
fallback.enqueue_block(queue, lM, N - lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), alpha, beta, program, "fallback", options);
fallback.enqueue_block(queue, lM, N - lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, 0, lM, lN, N, false), alpha, _1, program, "fallback", options);
fallback.enqueue_block(queue, M - lM, lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), alpha, beta, program, "fallback", options);
fallback.enqueue_block(queue, M - lM, lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, lM, M, 0, lN, false), alpha, _1, program, "fallback", options);
fallback.enqueue_block(queue, M - lM, N - lN, lK, create_slice(*pA, lM, M, 0, lK, swap_A), create_slice(*pB, 0, lK, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), alpha, beta, program, "fallback", options);
fallback.enqueue_block(queue, M - lM, N - lN, K - lK, create_slice(*pA, lM, M, lK, K, swap_A), create_slice(*pB, lK, K, lN, N, swap_B), create_slice(*pC, lM, M, lN, N, false), alpha, _1, program, "fallback", options);
enqueue_block(queue, M, N, K, create_slice(*pA, 0, M, 0, K, swap_A), create_slice(*pB, 0, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
}
//

View File

@@ -151,12 +151,12 @@ namespace detail
if(name=="vaxpy") return VECTOR_AXPY_TYPE;
if(name=="dot") return REDUCTION_TYPE;
if(name=="maxpy") return MATRIX_AXPY_TYPE;
if(name=="gemvN") return ROW_WISE_REDUCTION_TYPE;
if(name=="gemvT") return COL_WISE_REDUCTION_TYPE;
if(name=="gemmNN") return MATRIX_PRODUCT_NN_TYPE;
if(name=="gemmNT") return MATRIX_PRODUCT_NT_TYPE;
if(name=="gemmTN") return MATRIX_PRODUCT_TN_TYPE;
if(name=="gemmTT") return MATRIX_PRODUCT_TT_TYPE;
if(name=="mreduction_rows") return ROW_WISE_REDUCTION_TYPE;
if(name=="mreduction_cols") return COL_WISE_REDUCTION_TYPE;
if(name=="mproduct_nn") return MATRIX_PRODUCT_NN_TYPE;
if(name=="mproduct_nt") return MATRIX_PRODUCT_NT_TYPE;
if(name=="mproduct_tn") return MATRIX_PRODUCT_TN_TYPE;
if(name=="mproduct_tt") return MATRIX_PRODUCT_TT_TYPE;
throw std::invalid_argument("Invalid expression: " + name);
}
@@ -176,17 +176,17 @@ namespace detail
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="maxpy")
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemvN")!=std::string::npos)
else if(template_name.find("mreduction_rows")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemvT")!=std::string::npos)
else if(template_name.find("mreduction_cols")!=std::string::npos)
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemmNN")!=std::string::npos)
else if(template_name.find("mproduct_nn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmTN")!=std::string::npos)
else if(template_name.find("mproduct_tn")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmNT")!=std::string::npos)
else if(template_name.find("mproduct_nt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemmTT")!=std::string::npos)
else if(template_name.find("mproduct_tt")!=std::string::npos)
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
@@ -207,7 +207,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
document.Parse<0>(str.c_str());
//Deserialize
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemvN", "gemvT", "gemmNN", "gemmTN", "gemmNT", "gemmTT"};
std::vector<std::string> operations = {"vaxpy", "dot", "maxpy", "gemv_n", "gemv_t", "mproduct_nn", "mproduct_tn", "mproduct_nt", "mproduct_tt"};
std::vector<std::string> dtype = {"float32", "float64"};
for(auto & operation : operations)
{
@@ -256,10 +256,10 @@ std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > ini
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
}
return res;
}

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -58,14 +58,14 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
cl_command_queue clqueue = (*queue.handle().cl)();
//Row-major
RUN_TEST("GEMM(ROW, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
// CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
// CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
// CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
// CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
//Column-major
RUN_TEST("GEMM(COL, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
@@ -94,14 +94,12 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
template<typename T>
void test_impl(T epsilon, ad::driver::Context const & ctx)
{
int_t M = 427;
int_t N = 248;
int_t K = 376;
int_t N = 243;
int_t K = 384;
int_t SUBM = 61;
int_t SUBN = 75;
int_t SUBM = 75;
int_t SUBN = 76;
int_t SUBK = 83;
{

View File

@@ -16,8 +16,8 @@ import tools
fetch_types = [isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
isc.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED]
isc.fetching_policy_type.FETCH_FROM_LOCAL,
isc.fetching_policy_type.FETCH_FROM_LOCAL]
def exhaustive(template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context)

View File

@@ -22,12 +22,19 @@ def tune(device, operation, json_path):
context = isc.context(device)
#List of size tuples to use
sizes = list({isc.vaxpy: [(x,) for x in tools.expspace(1e3, 1e7, 4)],
isc.mreduction_cols: product(pow2range(4,17), pow2range(4,17)),
isc.mproduct_nt: product(pow2range(4, 17), pow2range(4, 17), pow2range(4, 17))}[operation])
sizes = unique(sizes)
sizes = {}
sizes[isc.vaxpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
sizes[isc.mreduction_rows] = product(pow2range(4,17), pow2range(4,17))
sizes[isc.mreduction_cols] = isc.mreduction_rows
sizes[isc.mproduct_nn] = product(pow2range(5, 10), pow2range(5, 10), pow2range(5, 10))
sizes[isc.mproduct_nn] = [(169, 128, 1728)]
sizes[isc.mproduct_tn] = sizes[isc.mproduct_nn]
sizes[isc.mproduct_nt] = sizes[isc.mproduct_nn]
sizes[isc.mproduct_tt] = sizes[isc.mproduct_nn]
sizes = unique(list(sizes[operation]))
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
#Training data
performance = tools.metric_of(operation)
profiles = []
@@ -73,13 +80,10 @@ def tune(device, operation, json_path):
X.append(x)
Y.append(y)
#Build model
clf, nrmse = model.train(X, Y, profiles)
print 'The optimal classifer has NRMSE = %.2g (%d estimators and the max depth is %d'%(nrmse, clf.n_estimators, clf.max_depth)
#Export to JSON
if os.path.isfile(json_path):
json_data = json.load(open(args.out, 'r'))
json_data = json.load(open(json_path, 'r'))
else:
json_data = {}
json_data["version"] = "1.0"
@@ -89,6 +93,7 @@ def tune(device, operation, json_path):
json_data[operation_name]['float32'] = {}
D = json_data[operation_name]['float32']
if len(profiles) > 1:
clf, nrmse = model.train(X, Y, profiles)
D['predictor'] = [{'children_left': e.tree_.children_left.tolist(),
'children_right': e.tree_.children_right.tolist(),
'threshold': e.tree_.threshold.astype('float64').tolist(),
@@ -113,14 +118,14 @@ def parse_arguments():
print("Devices available:")
print("----------------")
for (i, d) in enumerate(devices):
selected = '[' + ('x' if device==d else '') + ']'
selected = '[' + ('x' if device==d else ' ') + ']'
print selected , '-', isc.device_type_to_string(d.type), '-', d.name, 'on', d.platform.name
print("----------------")
operation = {'vaxpy': isc.vaxpy, 'dot': isc.reduction,
'maxpy': isc.maxpy, 'gemv_n': isc.mreduction_rows, 'gemv_t': isc.mreduction_cols,
'gemm_nn': isc.mproduct_nn, 'gemv_tn': isc.mproduct_tn, 'gemm_nt': isc.mproduct_nt, 'gemm_tt':isc.mproduct_tt}[args.operation]
'gemm_nn': isc.mproduct_nn, 'gemm_tn': isc.mproduct_tn, 'gemm_nt': isc.mproduct_nt, 'gemm_tt':isc.mproduct_tt}[args.operation]
if not args.json:
json = tools.sanitize(device.name) + '.json'
return (device, operation, json)