API: Fixes more issues in dot() corner cases
This commit is contained in:
@@ -902,13 +902,13 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
|
||||
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
|
||||
/*AXPY*/\
|
||||
if(numgt1(xs)==0 || numgt1(ys)==0)\
|
||||
return x*y;\
|
||||
return ravel(x*y);\
|
||||
/*Outer product*/\
|
||||
if(xs.back()==1 && ys.front()==1)\
|
||||
return x*y;\
|
||||
/*Inner product*/\
|
||||
if(numgt1(xs)==1 && numgt1(ys)==1)\
|
||||
return sum(trans(x)*y);\
|
||||
return sum(ravel(x)*ravel(y));\
|
||||
/*Matrix-Vector*/\
|
||||
if(numgt1(xs)==2 && numgt1(ys)==1){\
|
||||
if(rshape.size()>1)\
|
||||
@@ -1008,8 +1008,7 @@ void copy(void const * data, array_base& x, driver::CommandQueue & queue, bool b
|
||||
void copy(array_base const & x, void* data, driver::CommandQueue & queue, bool blocking)
|
||||
{
|
||||
unsigned int dtypesize = size_of(x.dtype());
|
||||
if(x.start()==0 && prod(x.stride())==prod(x.shape()))
|
||||
{
|
||||
if(x.start()==0 && prod(x.stride())==prod(x.shape())){
|
||||
queue.read(x.data(), blocking, 0, prod(x.shape())*dtypesize, data);
|
||||
}
|
||||
else
|
||||
|
Reference in New Issue
Block a user