Kernels: Fixed various corner cases for the kernel templates and BLAS
This commit is contained in:
@@ -903,6 +903,8 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
{\
|
||||
numeric_type dtype = x.dtype();\
|
||||
driver::Context const & context = x.context();\
|
||||
if(x.shape().max()==1 || y.shape().max()==1)\
|
||||
return x*y;\
|
||||
if(x.dim()==2 && x.shape()[1]==0)\
|
||||
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
|
||||
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
|
||||
@@ -927,10 +929,12 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
|
||||
else\
|
||||
return trans(detail::matvecprod(trans(y), trans(x)));\
|
||||
}\
|
||||
if(x.shape()[0]==1)\
|
||||
if(x.shape()[0]==1 && y.shape()[1]==1)\
|
||||
return sum(x*trans(y));\
|
||||
if(x.shape()[1]==1)\
|
||||
return outer(x, y);\
|
||||
if(x.shape()[0]==1 && y.shape()[1]==2)\
|
||||
return trans(detail::matvecprod(trans(y), trans(x)));\
|
||||
if(x.shape()[1]==1 && y.shape()[0]==1)\
|
||||
return x*y;\
|
||||
else /*if(x.dim()==2 && y.dim()==2)*/\
|
||||
return detail::matmatprod(x, y);\
|
||||
}
|
||||
@@ -995,7 +999,7 @@ void copy(void const * data, array_base& x, driver::CommandQueue & queue, bool b
|
||||
void copy(array_base const & x, void* data, driver::CommandQueue & queue, bool blocking)
|
||||
{
|
||||
unsigned int dtypesize = size_of(x.dtype());
|
||||
if(x.start()==0 && x.shape()[0]*x.stride().prod()==x.shape().prod())
|
||||
if(x.start()==0 && x.stride().prod()==x.shape().prod())
|
||||
{
|
||||
queue.read(x.data(), blocking, 0, x.shape().prod()*dtypesize, data);
|
||||
}
|
||||
|
Reference in New Issue
Block a user