Tests: Now using prime-numbered sizes for GEMM

This commit is contained in:
Philippe Tillet
2015-07-11 11:29:28 -04:00
parent cfa6ea812d
commit 84e47b871b
2 changed files with 21 additions and 21 deletions

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -58,26 +58,26 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
cl_command_queue clqueue = (*queue.handle().cl)();
//Row-major
// RUN_TEST("GEMM(ROW, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
// CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
// CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
// CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(ROW, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
// CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasNoTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
CHANDLE(A), OFF(A), LD(A), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasNoTrans, clblasTrans, N, M, K, alpha, CHANDLE(B), OFF(B), LD(B),
CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(ROW, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasRowMajor, clblasTrans, clblasTrans, N, M, K, alpha, CHANDLE(BT), OFF(BT), LD(BT),
CHANDLE(AT), OFF(AT), LD(AT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
//Column-major
// RUN_TEST("GEMM(COL, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
// CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(COL, N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(COL, N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(COL, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
// CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
// RUN_TEST("GEMM(COL, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
// CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(COL, T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(COL, T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
@@ -97,9 +97,9 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
template<typename T>
void test_impl(T epsilon, ad::driver::Context const & ctx)
{
int_t M = 64;
int_t N = 64;
int_t K = 64;
int_t M = 173;
int_t N = 233;
int_t K = 293;
int_t SUBM = 64;
int_t SUBN = 64;
@@ -110,15 +110,15 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
INIT_MATRIX(M, SUBM, 8, 1, K, SUBK, 4, 1, cA, A, ctx);
INIT_MATRIX(K, SUBK, 9, 1, N, SUBN, 6, 1, cB, B, ctx);
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, clBLAS, "BLAS, FULL");
// test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, clBLAS, "BLAS, SUB");
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, clBLAS, "BLAS, SUB");
}
{
INIT_MATRIX(M, SUBM, 5, 2, N, SUBN, 7, 3, cC, C, ctx);
INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx);
INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx);
// test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
// test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
}
}