C++/clBLAS: Bugfix in GEMM
This commit is contained in:
committed by
Philippe Tillet
parent
743a559f76
commit
8f19d2a69c
@@ -677,8 +677,7 @@ namespace detail
|
|||||||
|
|
||||||
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
bool A_trans = A_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
bool B_trans = B_root.op.type==OPERATOR_TRANS_TYPE;
|
||||||
if(A_trans) shape[0] = A.shape()[1];
|
|
||||||
if(B_trans) shape[1] = B.shape()[0];
|
|
||||||
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE;
|
if(A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_TT_TYPE;
|
||||||
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
|
else if(A_trans && !B_trans) type = OPERATOR_MATRIX_PRODUCT_TN_TYPE;
|
||||||
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
else if(!A_trans && B_trans) type = OPERATOR_MATRIX_PRODUCT_NT_TYPE;
|
||||||
|
@@ -202,8 +202,9 @@ extern "C"
|
|||||||
bool AeffTrans = (transA==clblasTrans) ^ (order==clblasRowMajor);\
|
bool AeffTrans = (transA==clblasTrans) ^ (order==clblasRowMajor);\
|
||||||
bool BeffTrans = (transB==clblasTrans) ^ (order==clblasRowMajor);\
|
bool BeffTrans = (transB==clblasTrans) ^ (order==clblasRowMajor);\
|
||||||
/*Operation*/\
|
/*Operation*/\
|
||||||
if(AeffTrans && BeffTrans)\
|
if(AeffTrans && BeffTrans){\
|
||||||
execute(is::detail::assign(C, alpha*dot(A.T(), B.T()) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
execute(is::detail::assign(C, alpha*dot(A.T(), B.T()) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
||||||
|
}\
|
||||||
else if(AeffTrans && !BeffTrans)\
|
else if(AeffTrans && !BeffTrans)\
|
||||||
execute(is::detail::assign(C, alpha*dot(A.T(), B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
execute(is::detail::assign(C, alpha*dot(A.T(), B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
||||||
else if(!AeffTrans && BeffTrans)\
|
else if(!AeffTrans && BeffTrans)\
|
||||||
@@ -214,6 +215,6 @@ extern "C"
|
|||||||
}
|
}
|
||||||
|
|
||||||
MAKE_GEMM(S, is::FLOAT_TYPE, cl_float)
|
MAKE_GEMM(S, is::FLOAT_TYPE, cl_float)
|
||||||
MAKE_GEMM(D, is::FLOAT_TYPE, cl_double)
|
MAKE_GEMM(D, is::DOUBLE_TYPE, cl_double)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -71,7 +71,7 @@ def main():
|
|||||||
#Includes
|
#Includes
|
||||||
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
||||||
#Sources
|
#Sources
|
||||||
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/command_queue.cpp src/lib/driver/backend.cpp src/lib/driver/program.cpp src/lib/driver/device.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/context.cpp src/lib/driver/check.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
src = 'src/lib/wrap/clBLAS.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/backend/parse.cpp src/lib/backend/templates/mproduct.cpp src/lib/backend/templates/vaxpy.cpp src/lib/backend/templates/reduction.cpp src/lib/backend/templates/mreduction.cpp src/lib/backend/templates/maxpy.cpp src/lib/backend/templates/base.cpp src/lib/backend/stream.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||||
boostsrc = 'external/boost/libs/'
|
boostsrc = 'external/boost/libs/'
|
||||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||||
|
@@ -13,6 +13,9 @@ template<> struct BLAS<double> { template<class FT, class DT> static DT F(FT , D
|
|||||||
enum interface_t{clBLAS, CPP};
|
enum interface_t{clBLAS, CPP};
|
||||||
|
|
||||||
#define CHANDLE(X) (*X.data().handle().cl)()
|
#define CHANDLE(X) (*X.data().handle().cl)()
|
||||||
|
#define OFF(X) X.start()[0] + X.start()[1]*X.ld()
|
||||||
|
#define LD(X) X.ld()*X.stride()[1]
|
||||||
|
|
||||||
/*------ Simple Vector ---------*/
|
/*------ Simple Vector ---------*/
|
||||||
template<class T>
|
template<class T>
|
||||||
class simple_vector_base
|
class simple_vector_base
|
||||||
|
@@ -56,14 +56,14 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
|
|||||||
if(interface==clBLAS)
|
if(interface==clBLAS)
|
||||||
{
|
{
|
||||||
cl_command_queue clqueue = (*queue.handle().cl)();
|
cl_command_queue clqueue = (*queue.handle().cl)();
|
||||||
ad::int_t offa = A.start()[0] + A.start()[1]*A.ld();
|
RUN_TEST("GEMM(N, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||||
ad::int_t lda = A.ld()*A.stride()[1];
|
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
|
||||||
ad::int_t offb = B.start()[0] + B.start()[1]*A.ld();
|
RUN_TEST("GEMM(N, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasNoTrans, clblasTrans, M, N, K, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||||
ad::int_t ldb = B.ld()*B.stride()[1];
|
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
|
||||||
ad::int_t offc = C.start()[0] + C.start()[1]*A.ld();
|
RUN_TEST("GEMM(T, N)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
|
||||||
ad::int_t ldc = C.ld()*C.stride()[1];
|
CHANDLE(B), OFF(B), LD(B), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
|
||||||
RUN_TEST("GEMM(COL, N, N)", clblasSgemm(clblasColumnMajor, clblasNoTrans, clblasNoTrans, M, N, K, alpha, CHANDLE(A), offa, lda, CHANDLE(B), offb, ldb, beta, CHANDLE(C), offc, ldc,
|
RUN_TEST("GEMM(T, T)", BLAS<T>::F(clblasSgemm,clblasDgemm)(clblasColumnMajor, clblasTrans, clblasTrans, M, N, K, alpha, CHANDLE(AT), OFF(AT), LD(AT),
|
||||||
1, &clqueue, 0, NULL, NULL));
|
CHANDLE(BT), OFF(BT), LD(BT), beta, CHANDLE(C), OFF(C), LD(C), 1, &clqueue, 0, NULL, NULL));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -93,6 +93,7 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
|
|||||||
INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx);
|
INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx);
|
||||||
INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx);
|
INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx);
|
||||||
std::cout << "full..." << std::endl;
|
std::cout << "full..." << std::endl;
|
||||||
|
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, clBLAS);
|
||||||
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP);
|
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP);
|
||||||
std::cout << "slice..." << std::endl;
|
std::cout << "slice..." << std::endl;
|
||||||
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP);
|
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP);
|
||||||
|
Reference in New Issue
Block a user