Backend: Now not creating a temporary upon C = alpha*dot(op(A), op(B)) + beta*C
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
#define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H
|
||||
|
||||
#include "isaac/backend/templates/base.h"
|
||||
#include "isaac/symbolic/expression.h"
|
||||
#include "isaac/symbolic/preset.h"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
@@ -46,7 +48,7 @@ private:
|
||||
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
|
||||
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
|
||||
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
||||
std::vector<int_t> infos(expressions_tuple const & expressions, lhs_rhs_element *&C, lhs_rhs_element *&A, lhs_rhs_element *&B, lhs_rhs_element *&alpha, lhs_rhs_element *&beta);
|
||||
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments);
|
||||
public:
|
||||
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions);
|
||||
|
49
include/isaac/symbolic/preset.h
Normal file
49
include/isaac/symbolic/preset.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef ISAAC_SYMBOLIC_PRESET_H_
|
||||
#define ISAAC_SYMBOLIC_PRESET_H_
|
||||
|
||||
#include "isaac/symbolic/expression.h"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
|
||||
namespace symbolic
|
||||
{
|
||||
|
||||
namespace preset
|
||||
{
|
||||
|
||||
|
||||
class gemm
|
||||
{
|
||||
|
||||
public:
|
||||
struct args
|
||||
{
|
||||
args(): alpha(NULL), A(NULL), B(NULL), beta(NULL), C(NULL), type(INVALID_EXPRESSION_TYPE){ }
|
||||
lhs_rhs_element* alpha;
|
||||
lhs_rhs_element* A;
|
||||
lhs_rhs_element* B;
|
||||
lhs_rhs_element* beta;
|
||||
lhs_rhs_element* C;
|
||||
expression_type type;
|
||||
|
||||
operator bool() const
|
||||
{
|
||||
return type!=INVALID_EXPRESSION_TYPE && C!=NULL;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
static void handle_node(array_expression::container_type &tree, size_t rootidx, args & a);
|
||||
|
||||
public:
|
||||
static args check(array_expression::container_type & tree, size_t rootidx);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -127,6 +127,7 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
|
||||
|
||||
enum expression_type
|
||||
{
|
||||
INVALID_EXPRESSION_TYPE,
|
||||
SCALAR_AXPY_TYPE,
|
||||
VECTOR_AXPY_TYPE,
|
||||
MATRIX_AXPY_TYPE,
|
||||
@@ -136,8 +137,7 @@ enum expression_type
|
||||
MATRIX_PRODUCT_NN_TYPE,
|
||||
MATRIX_PRODUCT_TN_TYPE,
|
||||
MATRIX_PRODUCT_NT_TYPE,
|
||||
MATRIX_PRODUCT_TT_TYPE,
|
||||
INVALID_EXPRESSION_TYPE
|
||||
MATRIX_PRODUCT_TT_TYPE
|
||||
};
|
||||
|
||||
struct slice
|
||||
|
Reference in New Issue
Block a user