Backend: Now not creating a temporary upon C = alpha*dot(op(A), op(B)) + beta*C

This commit is contained in:
Philippe Tillet
2015-06-27 17:55:01 -07:00
parent 3525edd54c
commit 0e207e7ca4
7 changed files with 206 additions and 67 deletions

View File

@@ -2,6 +2,8 @@
#define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H
#include "isaac/backend/templates/base.h"
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac
{
@@ -46,7 +48,7 @@ private:
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(expressions_tuple const & expressions, lhs_rhs_element *&C, lhs_rhs_element *&A, lhs_rhs_element *&B, lhs_rhs_element *&alpha, lhs_rhs_element *&beta);
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments);
public:
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expressions_tuple const & expressions);

View File

@@ -0,0 +1,49 @@
#ifndef ISAAC_SYMBOLIC_PRESET_H_
#define ISAAC_SYMBOLIC_PRESET_H_
#include "isaac/symbolic/expression.h"
namespace isaac
{
namespace symbolic
{
namespace preset
{
class gemm
{
public:
struct args
{
args(): alpha(NULL), A(NULL), B(NULL), beta(NULL), C(NULL), type(INVALID_EXPRESSION_TYPE){ }
lhs_rhs_element* alpha;
lhs_rhs_element* A;
lhs_rhs_element* B;
lhs_rhs_element* beta;
lhs_rhs_element* C;
expression_type type;
operator bool() const
{
return type!=INVALID_EXPRESSION_TYPE && C!=NULL;
}
};
private:
static void handle_node(array_expression::container_type &tree, size_t rootidx, args & a);
public:
static args check(array_expression::container_type & tree, size_t rootidx);
};
}
}
}
#endif

View File

@@ -127,6 +127,7 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
enum expression_type
{
INVALID_EXPRESSION_TYPE,
SCALAR_AXPY_TYPE,
VECTOR_AXPY_TYPE,
MATRIX_AXPY_TYPE,
@@ -136,8 +137,7 @@ enum expression_type
MATRIX_PRODUCT_NN_TYPE,
MATRIX_PRODUCT_TN_TYPE,
MATRIX_PRODUCT_NT_TYPE,
MATRIX_PRODUCT_TT_TYPE,
INVALID_EXPRESSION_TYPE
MATRIX_PRODUCT_TT_TYPE
};
struct slice