79 lines
2.5 KiB
C++
79 lines
2.5 KiB
C++
#include "isaac/symbolic/preset.h"
|
|
|
|
namespace isaac
|
|
{
|
|
|
|
namespace symbolic
|
|
{
|
|
|
|
namespace preset
|
|
{
|
|
|
|
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
|
|
{
|
|
//Matrix-Matrix product node
|
|
if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
|
{
|
|
if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs;
|
|
if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs;
|
|
switch(tree[rootidx].op.type)
|
|
{
|
|
case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break;
|
|
case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break;
|
|
case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break;
|
|
case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break;
|
|
default: break;
|
|
}
|
|
}
|
|
|
|
//Scalar multiplication node
|
|
if(tree[rootidx].op.type==OPERATOR_MULT_TYPE)
|
|
{
|
|
//alpha*PROD
|
|
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY
|
|
&& tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
|
{
|
|
a.alpha = &tree[rootidx].lhs;
|
|
handle_node(tree, tree[rootidx].rhs.node_index, a);
|
|
}
|
|
|
|
//beta*C
|
|
if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY)
|
|
{
|
|
a.beta = &tree[rootidx].lhs;
|
|
a.C = &tree[rootidx].rhs;
|
|
}
|
|
}
|
|
}
|
|
|
|
gemm::args gemm::check(array_expression::container_type & tree, size_t rootidx)
|
|
{
|
|
lhs_rhs_element * assigned = &tree[rootidx].lhs;
|
|
gemm::args result ;
|
|
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
|
{
|
|
rootidx = tree[rootidx].rhs.node_index;
|
|
//Form X + Y
|
|
if(tree[rootidx].op.type==OPERATOR_ADD_TYPE || tree[rootidx].op.type==OPERATOR_SUB_TYPE)
|
|
{
|
|
if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
|
handle_node(tree, tree[rootidx].lhs.node_index, result);
|
|
if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
|
|
handle_node(tree, tree[rootidx].rhs.node_index, result);
|
|
}
|
|
else
|
|
handle_node(tree, rootidx, result);
|
|
}
|
|
if(result.C == NULL)
|
|
result.C = assigned;
|
|
else if(result.C->array != assigned->array)
|
|
result.C = NULL;
|
|
return result;
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|