BLAS: Removed nasty temporary in GEMM...
This commit is contained in:
@@ -81,6 +81,7 @@ struct array_holder
|
|||||||
{
|
{
|
||||||
int_t start;
|
int_t start;
|
||||||
handle_t handle;
|
handle_t handle;
|
||||||
|
array_base* base;
|
||||||
};
|
};
|
||||||
|
|
||||||
class expression_tree
|
class expression_tree
|
||||||
|
@@ -47,6 +47,7 @@ expression_tree::node::node(value_scalar const & x) : type(VALUE_SCALAR_TYPE), d
|
|||||||
expression_tree::node::node(array_base const & x) : type(DENSE_ARRAY_TYPE), dtype(x.dtype()), shape(x.shape())
|
expression_tree::node::node(array_base const & x) : type(DENSE_ARRAY_TYPE), dtype(x.dtype()), shape(x.shape())
|
||||||
{
|
{
|
||||||
array.start = x.start();
|
array.start = x.start();
|
||||||
|
array.base = (array_base*)&x;
|
||||||
driver::Buffer::handle_type const & h = x.data().handle();
|
driver::Buffer::handle_type const & h = x.data().handle();
|
||||||
switch(h.backend()){
|
switch(h.backend()){
|
||||||
case driver::OPENCL: array.handle.cl = h.cl(); break;
|
case driver::OPENCL: array.handle.cl = h.cl(); break;
|
||||||
|
@@ -119,7 +119,7 @@ matrix_product::args matrix_product::check(expression_tree::data_type const & tr
|
|||||||
}
|
}
|
||||||
if(result.C == NULL)
|
if(result.C == NULL)
|
||||||
result.C = &left;
|
result.C = &left;
|
||||||
else if(result.C != &left)
|
else if(result.C->array.base != left.array.base)
|
||||||
result.C = NULL;
|
result.C = NULL;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user