Backend: Now not creating a temporary upon C = alpha*dot(op(A), op(B)) + beta*C
This commit is contained in:
@@ -2,6 +2,8 @@
|
|||||||
#define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H
|
#define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H
|
||||||
|
|
||||||
#include "isaac/backend/templates/base.h"
|
#include "isaac/backend/templates/base.h"
|
||||||
|
#include "isaac/symbolic/expression.h"
|
||||||
|
#include "isaac/symbolic/preset.h"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
@@ -46,7 +48,7 @@ private:
|
|||||||
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
|
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
|
||||||
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
|
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
|
||||||
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
||||||
std::vector<int_t> infos(expressions_tuple const & expressions, lhs_rhs_element *&C, lhs_rhs_element *&A, lhs_rhs_element *&B, lhs_rhs_element *&alpha, lhs_rhs_element *&beta);
|
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments);
|
||||||
public:
|
public:
|
||||||
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions);
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions);
|
||||||
|
49
include/isaac/symbolic/preset.h
Normal file
49
include/isaac/symbolic/preset.h
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
#ifndef ISAAC_SYMBOLIC_PRESET_H_
|
||||||
|
#define ISAAC_SYMBOLIC_PRESET_H_
|
||||||
|
|
||||||
|
#include "isaac/symbolic/expression.h"
|
||||||
|
|
||||||
|
namespace isaac
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace symbolic
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace preset
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
class gemm
|
||||||
|
{
|
||||||
|
|
||||||
|
public:
|
||||||
|
struct args
|
||||||
|
{
|
||||||
|
args(): alpha(NULL), A(NULL), B(NULL), beta(NULL), C(NULL), type(INVALID_EXPRESSION_TYPE){ }
|
||||||
|
lhs_rhs_element* alpha;
|
||||||
|
lhs_rhs_element* A;
|
||||||
|
lhs_rhs_element* B;
|
||||||
|
lhs_rhs_element* beta;
|
||||||
|
lhs_rhs_element* C;
|
||||||
|
expression_type type;
|
||||||
|
|
||||||
|
operator bool() const
|
||||||
|
{
|
||||||
|
return type!=INVALID_EXPRESSION_TYPE && C!=NULL;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void handle_node(array_expression::container_type &tree, size_t rootidx, args & a);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static args check(array_expression::container_type & tree, size_t rootidx);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -127,6 +127,7 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
|
|||||||
|
|
||||||
enum expression_type
|
enum expression_type
|
||||||
{
|
{
|
||||||
|
INVALID_EXPRESSION_TYPE,
|
||||||
SCALAR_AXPY_TYPE,
|
SCALAR_AXPY_TYPE,
|
||||||
VECTOR_AXPY_TYPE,
|
VECTOR_AXPY_TYPE,
|
||||||
MATRIX_AXPY_TYPE,
|
MATRIX_AXPY_TYPE,
|
||||||
@@ -136,8 +137,7 @@ enum expression_type
|
|||||||
MATRIX_PRODUCT_NN_TYPE,
|
MATRIX_PRODUCT_NN_TYPE,
|
||||||
MATRIX_PRODUCT_TN_TYPE,
|
MATRIX_PRODUCT_TN_TYPE,
|
||||||
MATRIX_PRODUCT_NT_TYPE,
|
MATRIX_PRODUCT_NT_TYPE,
|
||||||
MATRIX_PRODUCT_TT_TYPE,
|
MATRIX_PRODUCT_TT_TYPE
|
||||||
INVALID_EXPRESSION_TYPE
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct slice
|
struct slice
|
||||||
|
@@ -1,10 +1,12 @@
|
|||||||
#include "isaac/array.h"
|
#include "isaac/array.h"
|
||||||
#include "isaac/backend/templates/mproduct.h"
|
#include "isaac/backend/templates/mproduct.h"
|
||||||
#include "isaac/backend/keywords.h"
|
#include "isaac/backend/keywords.h"
|
||||||
|
#include "isaac/model/model.h"
|
||||||
|
#include "isaac/symbolic/preset.h"
|
||||||
#include "isaac/tools/make_vector.hpp"
|
#include "isaac/tools/make_vector.hpp"
|
||||||
#include "isaac/tools/to_string.hpp"
|
#include "isaac/tools/to_string.hpp"
|
||||||
#include "isaac/tools/miscellaneous.hpp"
|
#include "isaac/tools/miscellaneous.hpp"
|
||||||
#include "isaac/model/model.h"
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
|
||||||
@@ -640,17 +642,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
return array(M, s0, s1);
|
return array(M, s0, s1);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, lhs_rhs_element *& C, lhs_rhs_element *& A, lhs_rhs_element *& B, lhs_rhs_element *&, lhs_rhs_element *&)
|
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments)
|
||||||
{
|
{
|
||||||
isaac::array_expression & array_expression = (*expressions.data().front());
|
isaac::array_expression & array_expression = (*expressions.data().front());
|
||||||
array_expression::container_type & array = array_expression.tree();
|
array_expression::container_type & array = array_expression.tree();
|
||||||
std::size_t root = array_expression.root();
|
std::size_t root = array_expression.root();
|
||||||
C = &array[root].lhs;
|
arguments = symbolic::preset::gemm::check(array, root);
|
||||||
A = &array[array[root].rhs.node_index].lhs;
|
int_t M = arguments.C->array->shape()[0];
|
||||||
B = &array[array[root].rhs.node_index].rhs;
|
int_t N = arguments.C->array->shape()[1];
|
||||||
int_t M = C->array->shape()[0];
|
int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1];
|
||||||
int_t N = C->array->shape()[1];
|
|
||||||
int_t K = (A_trans_=='T')?A->array->shape()[0]:A->array->shape()[1];
|
|
||||||
return {M, N, K};
|
return {M, N, K};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,8 +665,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
|
|
||||||
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions)
|
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions)
|
||||||
{
|
{
|
||||||
lhs_rhs_element *d0, *d1, *d2, *d3, *d4;
|
symbolic::preset::gemm::args dummy;
|
||||||
return infos(expressions, d0, d1, d2, d3, d4);
|
return infos(expressions, dummy);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||||
@@ -677,8 +677,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
expressions_tuple const & expressions = ctr.x();
|
expressions_tuple const & expressions = ctr.x();
|
||||||
|
|
||||||
|
|
||||||
lhs_rhs_element * eC, *eA, *eB, *ealpha, *ebeta;
|
symbolic::preset::gemm::args args;
|
||||||
std::vector<int_t> MNK = infos(expressions, eC, eA, eB, ealpha, ebeta);
|
std::vector<int_t> MNK = infos(expressions, args);
|
||||||
|
|
||||||
int_t M = MNK[0];
|
int_t M = MNK[0];
|
||||||
int_t N = MNK[1];
|
int_t N = MNK[1];
|
||||||
@@ -688,9 +688,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
if(M==0 || N == 0 || K ==0)
|
if(M==0 || N == 0 || K ==0)
|
||||||
return;
|
return;
|
||||||
//Extract
|
//Extract
|
||||||
array * pA = eA->array;
|
array * pA = args.A->array;
|
||||||
array * pB = eB->array;
|
array * pB = args.B->array;
|
||||||
array * pC = eC->array;
|
array * pC = args.C->array;
|
||||||
|
|
||||||
//Check if requires fallback
|
//Check if requires fallback
|
||||||
int_t ldstrideA = pA->stride()[0];
|
int_t ldstrideA = pA->stride()[0];
|
||||||
@@ -699,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
int_t ldstartA = pA->start()[0];
|
int_t ldstartA = pA->start()[0];
|
||||||
int_t ldstartB = pB->start()[0];
|
int_t ldstartB = pB->start()[0];
|
||||||
|
|
||||||
numeric_type dtype = eC->dtype;
|
numeric_type dtype = args.C->dtype;
|
||||||
|
|
||||||
//Enqueue
|
//Enqueue
|
||||||
bool swap_A = (A_trans_=='T');
|
bool swap_A = (A_trans_=='T');
|
||||||
@@ -708,7 +708,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
|
|
||||||
value_scalar beta(0, dtype);
|
value_scalar beta(0, dtype);
|
||||||
value_scalar alpha(1, dtype);
|
value_scalar alpha(1, dtype);
|
||||||
value_scalar _1(1, dtype);
|
|
||||||
|
|
||||||
|
|
||||||
execution_options_type const & options = ctr.execution_options();
|
execution_options_type const & options = ctr.execution_options();
|
||||||
@@ -724,6 +723,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
int_t lM = M / p_.mL * p_.mL;
|
int_t lM = M / p_.mL * p_.mL;
|
||||||
int_t lN = N / p_.nL * p_.nL;
|
int_t lN = N / p_.nL * p_.nL;
|
||||||
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
|
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
|
||||||
|
value_scalar _1(1, dtype);
|
||||||
|
|
||||||
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
|
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
|
||||||
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);
|
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);
|
||||||
|
@@ -7,6 +7,7 @@
|
|||||||
#include <CL/cl.hpp>
|
#include <CL/cl.hpp>
|
||||||
#include "isaac/model/model.h"
|
#include "isaac/model/model.h"
|
||||||
#include "isaac/symbolic/expression.h"
|
#include "isaac/symbolic/expression.h"
|
||||||
|
#include "isaac/symbolic/preset.h"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
@@ -157,7 +158,6 @@ namespace isaac
|
|||||||
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
|
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @brief Executes a array_expression on the given models map*/
|
/** @brief Executes a array_expression on the given models map*/
|
||||||
@@ -172,61 +172,70 @@ namespace isaac
|
|||||||
//Todo: technically the datatype should be per temporary
|
//Todo: technically the datatype should be per temporary
|
||||||
numeric_type dtype = root_save.lhs.dtype;
|
numeric_type dtype = root_save.lhs.dtype;
|
||||||
|
|
||||||
detail::breakpoints_t breakpoints;
|
expression_type final_type;
|
||||||
breakpoints.reserve(8);
|
//GEMM
|
||||||
|
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
|
||||||
//Init
|
final_type = args.type;
|
||||||
expression_type current_type;
|
}
|
||||||
if(root_save.lhs.array->nshape()==0)
|
//Default
|
||||||
current_type = SCALAR_AXPY_TYPE;
|
|
||||||
else if(root_save.lhs.array->nshape()==1)
|
|
||||||
current_type=VECTOR_AXPY_TYPE;
|
|
||||||
else
|
else
|
||||||
current_type=MATRIX_AXPY_TYPE;
|
|
||||||
expression_type final_type = current_type;
|
|
||||||
|
|
||||||
/*----Parse required temporaries-----*/
|
|
||||||
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
|
|
||||||
std::vector<tools::shared_ptr<array> > temporaries_;
|
|
||||||
|
|
||||||
/*----Compute required temporaries----*/
|
|
||||||
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
|
|
||||||
{
|
{
|
||||||
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
|
detail::breakpoints_t breakpoints;
|
||||||
array_expression::node const & node = tree[rit->second->node_index];
|
breakpoints.reserve(8);
|
||||||
array_expression::node const & lmost = lhs_most(tree, node);
|
|
||||||
|
|
||||||
//Creates temporary
|
//Init
|
||||||
tools::shared_ptr<array> tmp;
|
expression_type current_type;
|
||||||
switch(rit->first){
|
if(root_save.lhs.array->nshape()==0)
|
||||||
case SCALAR_AXPY_TYPE:
|
current_type = SCALAR_AXPY_TYPE;
|
||||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
else if(root_save.lhs.array->nshape()==1)
|
||||||
|
current_type=VECTOR_AXPY_TYPE;
|
||||||
|
else
|
||||||
|
current_type=MATRIX_AXPY_TYPE;
|
||||||
|
final_type = current_type;
|
||||||
|
|
||||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
/*----Parse required temporaries-----*/
|
||||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
detail::parse(tree, rootidx, current_type, breakpoints, final_type);
|
||||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
std::vector<tools::shared_ptr<array> > temporaries_;
|
||||||
|
|
||||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
/*----Compute required temporaries----*/
|
||||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit)
|
||||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
{
|
||||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
tools::shared_ptr<model> const & pmodel = models[std::make_pair(rit->first, dtype)];
|
||||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
array_expression::node const & node = tree[rit->second->node_index];
|
||||||
|
array_expression::node const & lmost = lhs_most(tree, node);
|
||||||
|
|
||||||
default: throw std::invalid_argument("Unrecognized operation");
|
//Creates temporary
|
||||||
}
|
tools::shared_ptr<array> tmp;
|
||||||
temporaries_.push_back(tmp);
|
switch(rit->first){
|
||||||
|
case SCALAR_AXPY_TYPE:
|
||||||
|
case REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||||
|
|
||||||
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||||
fill(tree[rootidx].lhs, (array&)*tmp);
|
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||||
tree[rootidx].rhs = *rit->second;
|
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||||
tree[rootidx].rhs.type_family = rit->second->type_family;
|
|
||||||
|
|
||||||
//Execute
|
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||||
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||||
tree[rootidx] = root_save;
|
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||||
|
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||||
|
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||||
|
|
||||||
//Incorporates the temporary within the array_expression
|
default: throw std::invalid_argument("Unrecognized operation");
|
||||||
fill(*rit->second, (array&)*tmp);
|
}
|
||||||
|
temporaries_.push_back(tmp);
|
||||||
|
|
||||||
|
tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE;
|
||||||
|
fill(tree[rootidx].lhs, (array&)*tmp);
|
||||||
|
tree[rootidx].rhs = *rit->second;
|
||||||
|
tree[rootidx].rhs.type_family = rit->second->type_family;
|
||||||
|
|
||||||
|
//Execute
|
||||||
|
pmodel->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
|
||||||
|
tree[rootidx] = root_save;
|
||||||
|
|
||||||
|
//Incorporates the temporary within the array_expression
|
||||||
|
fill(*rit->second, (array&)*tmp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*-----Compute final expression-----*/
|
/*-----Compute final expression-----*/
|
||||||
|
@@ -4,6 +4,7 @@
|
|||||||
#include "isaac/value_scalar.h"
|
#include "isaac/value_scalar.h"
|
||||||
#include <CL/cl.hpp>
|
#include <CL/cl.hpp>
|
||||||
#include "isaac/symbolic/expression.h"
|
#include "isaac/symbolic/expression.h"
|
||||||
|
#include "isaac/symbolic/preset.h"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
{
|
{
|
||||||
|
78
lib/symbolic/preset.cpp
Normal file
78
lib/symbolic/preset.cpp
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#include "isaac/symbolic/preset.h"
|
||||||
|
|
||||||
|
namespace isaac
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace symbolic
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace preset
|
||||||
|
{
|
||||||
|
|
||||||
|
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
|
||||||
|
{
|
||||||
|
//Matrix-Matrix product node
|
||||||
|
if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
||||||
|
{
|
||||||
|
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
|
||||||
|
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
|
||||||
|
switch(tree[rootidx].op.type)
|
||||||
|
{
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break;
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break;
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break;
|
||||||
|
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Scalar multiplication node
|
||||||
|
if(tree[rootidx].op.type==OPERATOR_MULT_TYPE)
|
||||||
|
{
|
||||||
|
//alpha*PROD
|
||||||
|
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
|
||||||
|
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
||||||
|
{
|
||||||
|
a.alpha = &tree[rootidx].lhs;
|
||||||
|
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
//beta*C
|
||||||
|
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
|
||||||
|
{
|
||||||
|
a.beta = &tree[rootidx].lhs;
|
||||||
|
a.C = &tree[rootidx].rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gemm::args gemm::check(array_expression::container_type & tree, size_t rootidx)
|
||||||
|
{
|
||||||
|
lhs_rhs_element * assigned = &tree[rootidx].lhs;
|
||||||
|
gemm::args result ;
|
||||||
|
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
||||||
|
{
|
||||||
|
rootidx = tree[rootidx].rhs.node_index;
|
||||||
|
//Form X + Y
|
||||||
|
if(tree[rootidx].op.type==OPERATOR_ADD_TYPE || tree[rootidx].op.type==OPERATOR_SUB_TYPE)
|
||||||
|
{
|
||||||
|
if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
||||||
|
handle_node(tree, tree[rootidx].lhs.node_index, result);
|
||||||
|
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
||||||
|
handle_node(tree, tree[rootidx].rhs.node_index, result);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
handle_node(tree, rootidx, result);
|
||||||
|
}
|
||||||
|
if(result.C == NULL)
|
||||||
|
result.C = assigned;
|
||||||
|
else if(result.C->array != assigned->array)
|
||||||
|
result.C = NULL;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user