From add123da110ba1b7a01e6b79aab20b1509011bbd Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 1 Jul 2016 19:07:15 -0700 Subject: [PATCH] API: Safer forwarding of dot() --- lib/array.cpp | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/lib/array.cpp b/lib/array.cpp index 74c4d8b40..c9e09e0f2 100644 --- a/lib/array.cpp +++ b/lib/array.cpp @@ -900,30 +900,13 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\ /*Empty result*/\ if(xs.front()==0 || ys.back()==0)\ return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_ARITHMETIC, INVALID_TYPE), &context, dtype, {0});\ - /*AXPY*/\ - if(numgt1(xs)==0 || numgt1(ys)==0)\ - return x*y;\ - /*Outer product*/\ - if(xs.back()==1 && ys.front()==1)\ - return x*y;\ - /*Inner product*/\ - if(numgt1(xs)==1 && numgt1(ys)==1)\ - return sum(ravel(x)*ravel(y));\ - /*Matrix-Vector*/\ - if(numgt1(xs)==2 && numgt1(ys)==1){\ - if(rshape.size()>1)\ - return reshape(detail::matvecprod(x, y), rshape);\ - else\ - return detail::matvecprod(x, y);\ - }\ - if(numgt1(xs)==1 && numgt1(ys)==2){\ - if(rshape.size()>1)\ - return reshape(detail::matvecprod(trans(y), ravel(x)), rshape);\ - else\ - return detail::matvecprod(trans(y), ravel(x));\ - }\ - else /*if(numgt1(x)==2 && numgt1(y)==2)*/\ - return detail::matmatprod(x, y);\ + if(xs.size()==1 && ys.size()==1)\ + return sum(x*y);\ + if(xs.size()==2 && ys.size()==1)\ + return detail::matvecprod(x, y);\ + if(xs.size()==1 && ys.size()==2)\ + return detail::matvecprod(trans(y), x);\ + return detail::matmatprod(x, y);\ }