API: Fixes more issues in dot() corner cases

This commit is contained in:
Philippe Tillet
2016-07-01 00:08:59 -07:00
parent 8834ec9fe2
commit dbfaef8886
2 changed files with 4 additions and 6 deletions

View File

@@ -902,13 +902,13 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
/*AXPY*/\
if(numgt1(xs)==0 || numgt1(ys)==0)\
return x*y;\
return ravel(x*y);\
/*Outer product*/\
if(xs.back()==1 && ys.front()==1)\
return x*y;\
/*Inner product*/\
if(numgt1(xs)==1 && numgt1(ys)==1)\
return sum(trans(x)*y);\
return sum(ravel(x)*ravel(y));\
/*Matrix-Vector*/\
if(numgt1(xs)==2 && numgt1(ys)==1){\
if(rshape.size()>1)\
@@ -1008,8 +1008,7 @@ void copy(void const * data, array_base& x, driver::CommandQueue & queue, bool b
void copy(array_base const & x, void* data, driver::CommandQueue & queue, bool blocking)
{
unsigned int dtypesize = size_of(x.dtype());
if(x.start()==0 && prod(x.stride())==prod(x.shape()))
{
if(x.start()==0 && prod(x.stride())==prod(x.shape())){
queue.read(x.data(), blocking, 0, prod(x.shape())*dtypesize, data);
}
else