Kernels: Fixed various corner cases for the kernel templates and BLAS

This commit is contained in:
Philippe Tillet
2015-11-25 18:42:25 -05:00
parent 6be5929b0d
commit 6fc94c0c0b
15 changed files with 107 additions and 38 deletions

View File

@@ -903,6 +903,8 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
{\
numeric_type dtype = x.dtype();\
driver::Context const & context = x.context();\
if(x.shape().max()==1 || y.shape().max()==1)\
return x*y;\
if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
@@ -927,10 +929,12 @@ math_expression dot(LTYPE const & x, RTYPE const & y)\
else\
return trans(detail::matvecprod(trans(y), trans(x)));\
}\
if(x.shape()[0]==1)\
if(x.shape()[0]==1 && y.shape()[1]==1)\
return sum(x*trans(y));\
if(x.shape()[1]==1)\
return outer(x, y);\
if(x.shape()[0]==1 && y.shape()[1]==2)\
return trans(detail::matvecprod(trans(y), trans(x)));\
if(x.shape()[1]==1 && y.shape()[0]==1)\
return x*y;\
else /*if(x.dim()==2 && y.dim()==2)*/\
return detail::matmatprod(x, y);\
}
@@ -995,7 +999,7 @@ void copy(void const * data, array_base& x, driver::CommandQueue & queue, bool b
void copy(array_base const & x, void* data, driver::CommandQueue & queue, bool blocking)
{
unsigned int dtypesize = size_of(x.dtype());
if(x.start()==0 && x.shape()[0]*x.stride().prod()==x.shape().prod())
if(x.start()==0 && x.stride().prod()==x.shape().prod())
{
queue.read(x.data(), blocking, 0, x.shape().prod()*dtypesize, data);
}