Backend: A lot of bugfixes in dot() for handling shapes better

This commit is contained in:
Philippe Tillet
2015-06-30 17:55:57 -04:00
parent e7cabf65ac
commit cf2dba43ef
12 changed files with 108 additions and 73 deletions

View File

@@ -59,7 +59,7 @@ extern "C"
clRetainMemObject(mx); \
is::array y(N, TYPE_ISAAC, cl::Buffer(my), offy, incy); \
clRetainMemObject(my); \
execute(is::assign(y, x + alpha*y), y.context(), numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events); \
execute(is::assign(y, alpha*x + y), y.context(), numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events); \
return clblasSuccess; \
}
@@ -157,15 +157,14 @@ extern "C"
std::swap(M, N);\
transA = (transA==clblasTrans)?clblasNoTrans:clblasTrans;\
}\
is::int_t As1 = M, As2 = N;\
if(transA==clblasTrans) std::swap(As1, As2);\
is::array A(As1, As2, TYPE_ISAAC, cl::Buffer(mA), offA, lda);\
is::array A(M, N, TYPE_ISAAC, cl::Buffer(mA), offA, lda);\
clRetainMemObject(mA);\
\
is::array x(N, TYPE_ISAAC, cl::Buffer(mx), offx, incx);\
is::int_t sx = N, sy = M;\
if(transA) std::swap(sx, sy);\
is::array x(sx, TYPE_ISAAC, cl::Buffer(mx), offx, incx);\
clRetainMemObject(mx);\
\
is::array y(M, TYPE_ISAAC, cl::Buffer(my), offy, incy);\
is::array y(sy, TYPE_ISAAC, cl::Buffer(my), offy, incy);\
clRetainMemObject(my);\
\
is::driver::Context const & context = A.context();\
@@ -182,6 +181,7 @@ extern "C"
//*****************
//BLAS3
//*****************
#define MAKE_GEMM(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
clblasStatus clblas ## TYPE_CHAR ## gemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,\
size_t M, size_t N, size_t K,\
@@ -198,8 +198,7 @@ extern "C"
std::swap(offA, offB);\
std::swap(lda, ldb);\
std::swap(M, N);\
transA = (transA==clblasTrans)?clblasNoTrans:clblasTrans;\
transB = (transB==clblasTrans)?clblasNoTrans:clblasTrans;\
std::swap(transA, transB);\
}\
is::int_t As1 = M, As2 = K;\
is::int_t Bs1 = K, Bs2 = N;\
@@ -214,9 +213,8 @@ extern "C"
clRetainMemObject(mC);\
is::driver::Context const & context = C.context();\
/*Operation*/\
if((transA==clblasTrans) && (transB==clblasTrans)){\
if((transA==clblasTrans) && (transB==clblasTrans))\
execute(is::assign(C, alpha*dot(A.T(), B.T()) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
}\
else if((transA==clblasTrans) && (transB==clblasNoTrans))\
execute(is::assign(C, alpha*dot(A.T(), B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);\
else if((transA==clblasNoTrans) && (transB==clblasTrans))\
@@ -229,4 +227,6 @@ extern "C"
MAKE_GEMM(S, is::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, is::DOUBLE_TYPE, cl_double)
#undef DOT
}