API: Safer forwarding of dot()
This commit is contained in:
@@ -900,30 +900,13 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
|
|||||||
/*Empty result*/\
|
/*Empty result*/\
|
||||||
if(xs.front()==0 || ys.back()==0)\
|
if(xs.front()==0 || ys.back()==0)\
|
||||||
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
|
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
|
||||||
/*AXPY*/\
|
if(xs.size()==1 && ys.size()==1)\
|
||||||
if(numgt1(xs)==0 || numgt1(ys)==0)\
|
return sum(x*y);\
|
||||||
return x*y;\
|
if(xs.size()==2 && ys.size()==1)\
|
||||||
/*Outer product*/\
|
return detail::matvecprod(x, y);\
|
||||||
if(xs.back()==1 && ys.front()==1)\
|
if(xs.size()==1 && ys.size()==2)\
|
||||||
return x*y;\
|
return detail::matvecprod(trans(y), x);\
|
||||||
/*Inner product*/\
|
return detail::matmatprod(x, y);\
|
||||||
if(numgt1(xs)==1 && numgt1(ys)==1)\
|
|
||||||
return sum(ravel(x)*ravel(y));\
|
|
||||||
/*Matrix-Vector*/\
|
|
||||||
if(numgt1(xs)==2 && numgt1(ys)==1){\
|
|
||||||
if(rshape.size()>1)\
|
|
||||||
return reshape(detail::matvecprod(x, y), rshape);\
|
|
||||||
else\
|
|
||||||
return detail::matvecprod(x, y);\
|
|
||||||
}\
|
|
||||||
if(numgt1(xs)==1 && numgt1(ys)==2){\
|
|
||||||
if(rshape.size()>1)\
|
|
||||||
return reshape(detail::matvecprod(trans(y), ravel(x)), rshape);\
|
|
||||||
else\
|
|
||||||
return detail::matvecprod(trans(y), ravel(x));\
|
|
||||||
}\
|
|
||||||
else /*if(numgt1(x)==2 && numgt1(y)==2)*/\
|
|
||||||
return detail::matmatprod(x, y);\
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user