C++/clBLAS: Bugfix in GEMM

This commit is contained in:
Philippe
2015-06-27 13:53:31 -04:00
committed by Philippe Tillet
parent 743a559f76
commit 8f19d2a69c
5 changed files with 17 additions and 13 deletions

View File

@@ -677,8 +677,7 @@ namespace detail
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
if(A_trans) shape[0] = A.shape()[1];
if(B_trans) shape[1] = B.shape()[0];
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE;
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;

View File

@@ -202,8 +202,9 @@ extern "C"
bool AeffTrans = (transA==clblasTrans) ^ (order==clblasRowMajor);\
bool BeffTrans = (transB==clblasTrans) ^ (order==clblasRowMajor);\
/*Operation*/\
if(AeffTrans && BeffTrans)\
if(AeffTrans && BeffTrans){\
execute(is::detail::assign(C, alpha*dot(A.T(), B.T()) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
}\
else if(AeffTrans && !BeffTrans)\
execute(is::detail::assign(C, alpha*dot(A.T(), B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
else if(!AeffTrans && BeffTrans)\
@@ -214,6 +215,6 @@ extern "C"
}
MAKE_GEMM(S, is::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, is::FLOAT_TYPE, cl_double)
MAKE_GEMM(D, is::DOUBLE_TYPE, cl_double)
}

View File

@@ -71,7 +71,7 @@ def main():
#Includes
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Sources
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/command_queue.cpp src/lib/driver/backend.cpp src/lib/driver/program.cpp src/lib/driver/device.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/context.cpp src/lib/driver/check.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/wrap/clBLAS.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -13,6 +13,9 @@ template<> struct BLAS<double> { template<class FT, class DT> static DT F(FT , D
enum interface_t{clBLAS, CPP};
#define CHANDLE(X) (*X.data().handle().cl)()
#define OFF(X) X.start()[0] + X.start()[1]*X.ld()
#define LD(X) X.ld()*X.stride()[1]
/*------ Simple Vector ---------*/
template<class T>
class simple_vector_base

View File

@@ -56,14 +56,14 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
if(interface==clBLAS)
{
cl_command_queue clqueue = (*queue.handle().cl)();
ad::int_t offa = A.start()[0] + A.start()[1]*A.ld();
ad::int_t lda = A.ld()*A.stride()[1];
ad::int_t offb = B.start()[0] + B.start()[1]*A.ld();
ad::int_t ldb = B.ld()*B.stride()[1];
ad::int_t offc = C.start()[0] + C.start()[1]*A.ld();
ad::int_t ldc = C.ld()*C.stride()[1];
RUN_TEST("GEMM(COL, N, N)", clblasSgemm(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), offa, lda, CHANDLE(B), offb, ldb, beta, CHANDLE(C), offc, ldc,
1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
RUN_TEST("GEMM(T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
}
else
{
@@ -93,6 +93,7 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx);
INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx);
std::cout << "full..." << std::endl;
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, clBLAS);
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP);
std::cout << "slice..." << std::endl;
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP);