Backend: A lot of bugfixes in dot() for handling shapes better
This commit is contained in:
@@ -59,7 +59,7 @@ extern "C"
|
||||
clRetainMemObject(mx); \
|
||||
is::array y(N, TYPE_ISAAC, cl::Buffer(my), offy, incy); \
|
||||
clRetainMemObject(my); \
|
||||
execute(is::assign(y, x + alpha*y), y.context(), numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events); \
|
||||
execute(is::assign(y, alpha*x + y), y.context(), numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events); \
|
||||
return clblasSuccess; \
|
||||
}
|
||||
|
||||
@@ -157,15 +157,14 @@ extern "C"
|
||||
std::swap(M, N);\
|
||||
transA = (transA==clblasTrans)?clblasNoTrans:clblasTrans;\
|
||||
}\
|
||||
is::int_t As1 = M, As2 = N;\
|
||||
if(transA==clblasTrans) std::swap(As1, As2);\
|
||||
is::array A(As1, As2, TYPE_ISAAC, cl::Buffer(mA), offA, lda);\
|
||||
is::array A(M, N, TYPE_ISAAC, cl::Buffer(mA), offA, lda);\
|
||||
clRetainMemObject(mA);\
|
||||
\
|
||||
is::array x(N, TYPE_ISAAC, cl::Buffer(mx), offx, incx);\
|
||||
is::int_t sx = N, sy = M;\
|
||||
if(transA) std::swap(sx, sy);\
|
||||
is::array x(sx, TYPE_ISAAC, cl::Buffer(mx), offx, incx);\
|
||||
clRetainMemObject(mx);\
|
||||
\
|
||||
is::array y(M, TYPE_ISAAC, cl::Buffer(my), offy, incy);\
|
||||
is::array y(sy, TYPE_ISAAC, cl::Buffer(my), offy, incy);\
|
||||
clRetainMemObject(my);\
|
||||
\
|
||||
is::driver::Context const & context = A.context();\
|
||||
@@ -182,6 +181,7 @@ extern "C"
|
||||
//*****************
|
||||
//BLAS3
|
||||
//*****************
|
||||
|
||||
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
|
||||
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
|
||||
size_t M, size_t N, size_t K,\
|
||||
@@ -198,8 +198,7 @@ extern "C"
|
||||
std::swap(offA, offB);\
|
||||
std::swap(lda, ldb);\
|
||||
std::swap(M, N);\
|
||||
transA = (transA==clblasTrans)?clblasNoTrans:clblasTrans;\
|
||||
transB = (transB==clblasTrans)?clblasNoTrans:clblasTrans;\
|
||||
std::swap(transA, transB);\
|
||||
}\
|
||||
is::int_t As1 = M, As2 = K;\
|
||||
is::int_t Bs1 = K, Bs2 = N;\
|
||||
@@ -214,9 +213,8 @@ extern "C"
|
||||
clRetainMemObject(mC);\
|
||||
is::driver::Context const & context = C.context();\
|
||||
/*Operation*/\
|
||||
if((transA==clblasTrans) && (transB==clblasTrans)){\
|
||||
if((transA==clblasTrans) && (transB==clblasTrans))\
|
||||
execute(is::assign(C, alpha*dot(A.T(), B.T()) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
||||
}\
|
||||
else if((transA==clblasTrans) && (transB==clblasNoTrans))\
|
||||
execute(is::assign(C, alpha*dot(A.T(), B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
|
||||
else if((transA==clblasNoTrans) && (transB==clblasTrans))\
|
||||
@@ -229,4 +227,6 @@ extern "C"
|
||||
MAKE_GEMM(S, is::FLOAT_TYPE, cl_float)
|
||||
MAKE_GEMM(D, is::DOUBLE_TYPE, cl_double)
|
||||
|
||||
#undef DOT
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user