API: Safer forwarding of dot()

This commit is contained in:
Philippe Tillet
2016-07-01 19:07:15 -07:00
parent faea220464
commit add123da11

View File

@@ -900,30 +900,13 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
/*Empty result*/\
if(xs.front()==0 || ys.back()==0)\
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\
/*AXPY*/\
if(numgt1(xs)==0 || numgt1(ys)==0)\
return x*y;\
/*Outer product*/\
if(xs.back()==1 && ys.front()==1)\
return x*y;\
/*Inner product*/\
if(numgt1(xs)==1 && numgt1(ys)==1)\
return sum(ravel(x)*ravel(y));\
/*Matrix-Vector*/\
if(numgt1(xs)==2 && numgt1(ys)==1){\
if(rshape.size()>1)\
return reshape(detail::matvecprod(x, y), rshape);\
else\
return detail::matvecprod(x, y);\
}\
if(numgt1(xs)==1 && numgt1(ys)==2){\
if(rshape.size()>1)\
return reshape(detail::matvecprod(trans(y), ravel(x)), rshape);\
else\
return detail::matvecprod(trans(y), ravel(x));\
}\
else /*if(numgt1(x)==2 && numgt1(y)==2)*/\
return detail::matmatprod(x, y);\
if(xs.size()==1 && ys.size()==1)\
return sum(x*y);\
if(xs.size()==2 && ys.size()==1)\
return detail::matvecprod(x, y);\
if(xs.size()==1 && ys.size()==2)\
return detail::matvecprod(trans(y), x);\
return detail::matmatprod(x, y);\
}