API: Safer forwarding of dot()
This commit is contained in:
@@ -900,29 +900,12 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
|
||||
/*Empty result*/\
|
||||
if(xs.front()==0 || ys.back()==0)\
|
||||
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
|
||||
/*AXPY*/\
|
||||
if(numgt1(xs)==0 || numgt1(ys)==0)\
|
||||
return x*y;\
|
||||
/*Outer product*/\
|
||||
if(xs.back()==1 && ys.front()==1)\
|
||||
return x*y;\
|
||||
/*Inner product*/\
|
||||
if(numgt1(xs)==1 && numgt1(ys)==1)\
|
||||
return sum(ravel(x)*ravel(y));\
|
||||
/*Matrix-Vector*/\
|
||||
if(numgt1(xs)==2 && numgt1(ys)==1){\
|
||||
if(rshape.size()>1)\
|
||||
return reshape(detail::matvecprod(x, y), rshape);\
|
||||
else\
|
||||
if(xs.size()==1 && ys.size()==1)\
|
||||
return sum(x*y);\
|
||||
if(xs.size()==2 && ys.size()==1)\
|
||||
return detail::matvecprod(x, y);\
|
||||
}\
|
||||
if(numgt1(xs)==1 && numgt1(ys)==2){\
|
||||
if(rshape.size()>1)\
|
||||
return reshape(detail::matvecprod(trans(y), ravel(x)), rshape);\
|
||||
else\
|
||||
return detail::matvecprod(trans(y), ravel(x));\
|
||||
}\
|
||||
else /*if(numgt1(x)==2 && numgt1(y)==2)*/\
|
||||
if(xs.size()==1 && ys.size()==2)\
|
||||
return detail::matvecprod(trans(y), x);\
|
||||
return detail::matmatprod(x, y);\
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user