diff --git a/include/isaac/backend/templates/mproduct.h b/include/isaac/backend/templates/mproduct.h index 1f87c66a1..ecef0724b 100644 --- a/include/isaac/backend/templates/mproduct.h +++ b/include/isaac/backend/templates/mproduct.h @@ -2,6 +2,8 @@ #define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H #include "isaac/backend/templates/base.h" +#include "isaac/symbolic/expression.h" +#include "isaac/symbolic/preset.h" namespace isaac { @@ -46,7 +48,7 @@ private: void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C, value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options); array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap); - std::vector infos(expressions_tuple const & expressions, lhs_rhs_element *&C, lhs_rhs_element *&A, lhs_rhs_element *&B, lhs_rhs_element *&alpha, lhs_rhs_element *&beta); + std::vector infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments); public: mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans); std::vector input_sizes(expressions_tuple const & expressions); diff --git a/include/isaac/symbolic/preset.h b/include/isaac/symbolic/preset.h new file mode 100644 index 000000000..2d75fb4da --- /dev/null +++ b/include/isaac/symbolic/preset.h @@ -0,0 +1,49 @@ +#ifndef ISAAC_SYMBOLIC_PRESET_H_ +#define ISAAC_SYMBOLIC_PRESET_H_ + +#include "isaac/symbolic/expression.h" + +namespace isaac +{ + +namespace symbolic +{ + +namespace preset +{ + + +class gemm +{ + +public: + struct args + { + args(): alpha(NULL), A(NULL), B(NULL), beta(NULL), C(NULL), type(INVALID_EXPRESSION_TYPE){ } + lhs_rhs_element* alpha; + lhs_rhs_element* A; + lhs_rhs_element* B; + lhs_rhs_element* beta; + lhs_rhs_element* C; + expression_type type; + + operator bool() const + { + return type!=INVALID_EXPRESSION_TYPE && C!=NULL; + } + }; + +private: + static void handle_node(array_expression::container_type &tree, size_t rootidx, args & a); + +public: + static args check(array_expression::container_type & tree, size_t rootidx); +}; + +} + +} + +} + +#endif diff --git a/include/isaac/types.h b/include/isaac/types.h index 6f90e8584..eaf999b8d 100644 --- a/include/isaac/types.h +++ b/include/isaac/types.h @@ -127,6 +127,7 @@ template<> struct to_numeric_type { static const numeric_type value = DO enum expression_type { + INVALID_EXPRESSION_TYPE, SCALAR_AXPY_TYPE, VECTOR_AXPY_TYPE, MATRIX_AXPY_TYPE, @@ -136,8 +137,7 @@ enum expression_type MATRIX_PRODUCT_NN_TYPE, MATRIX_PRODUCT_TN_TYPE, MATRIX_PRODUCT_NT_TYPE, - MATRIX_PRODUCT_TT_TYPE, - INVALID_EXPRESSION_TYPE + MATRIX_PRODUCT_TT_TYPE }; struct slice diff --git a/lib/backend/templates/mproduct.cpp b/lib/backend/templates/mproduct.cpp index 66cfb732e..4e353067c 100644 --- a/lib/backend/templates/mproduct.cpp +++ b/lib/backend/templates/mproduct.cpp @@ -1,10 +1,12 @@ #include "isaac/array.h" #include "isaac/backend/templates/mproduct.h" #include "isaac/backend/keywords.h" +#include "isaac/model/model.h" +#include "isaac/symbolic/preset.h" #include "isaac/tools/make_vector.hpp" #include "isaac/tools/to_string.hpp" #include "isaac/tools/miscellaneous.hpp" -#include "isaac/model/model.h" + namespace isaac { @@ -640,17 +642,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width return array(M, s0, s1); } - std::vector mproduct::infos(expressions_tuple const & expressions, lhs_rhs_element *& C, lhs_rhs_element *& A, lhs_rhs_element *& B, lhs_rhs_element *&, lhs_rhs_element *&) + std::vector mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) { isaac::array_expression & array_expression = (*expressions.data().front()); array_expression::container_type & array = array_expression.tree(); std::size_t root = array_expression.root(); - C = &array[root].lhs; - A = &array[array[root].rhs.node_index].lhs; - B = &array[array[root].rhs.node_index].rhs; - int_t M = C->array->shape()[0]; - int_t N = C->array->shape()[1]; - int_t K = (A_trans_=='T')?A->array->shape()[0]:A->array->shape()[1]; + arguments = symbolic::preset::gemm::check(array, root); + int_t M = arguments.C->array->shape()[0]; + int_t N = arguments.C->array->shape()[1]; + int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1]; return {M, N, K}; } @@ -665,8 +665,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width std::vector mproduct::input_sizes(expressions_tuple const & expressions) { - lhs_rhs_element *d0, *d1, *d2, *d3, *d4; - return infos(expressions, d0, d1, d2, d3, d4); + symbolic::preset::gemm::args dummy; + return infos(expressions, dummy); } void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller const & ctr) @@ -677,8 +677,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width expressions_tuple const & expressions = ctr.x(); - lhs_rhs_element * eC, *eA, *eB, *ealpha, *ebeta; - std::vector MNK = infos(expressions, eC, eA, eB, ealpha, ebeta); + symbolic::preset::gemm::args args; + std::vector MNK = infos(expressions, args); int_t M = MNK[0]; int_t N = MNK[1]; @@ -688,9 +688,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width if(M==0 || N == 0 || K ==0) return; //Extract - array * pA = eA->array; - array * pB = eB->array; - array * pC = eC->array; + array * pA = args.A->array; + array * pB = args.B->array; + array * pC = args.C->array; //Check if requires fallback int_t ldstrideA = pA->stride()[0]; @@ -699,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width int_t ldstartA = pA->start()[0]; int_t ldstartB = pB->start()[0]; - numeric_type dtype = eC->dtype; + numeric_type dtype = args.C->dtype; //Enqueue bool swap_A = (A_trans_=='T'); @@ -708,7 +708,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width value_scalar beta(0, dtype); value_scalar alpha(1, dtype); - value_scalar _1(1, dtype); execution_options_type const & options = ctr.execution_options(); @@ -724,6 +723,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width int_t lM = M / p_.mL * p_.mL; int_t lN = N / p_.nL * p_.nL; int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth; + value_scalar _1(1, dtype); enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options); fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options); diff --git a/lib/symbolic/execute.cpp b/lib/symbolic/execute.cpp index fb2673967..89be6b00d 100644 --- a/lib/symbolic/execute.cpp +++ b/lib/symbolic/execute.cpp @@ -7,6 +7,7 @@ #include #include "isaac/model/model.h" #include "isaac/symbolic/expression.h" +#include "isaac/symbolic/preset.h" namespace isaac { @@ -157,7 +158,6 @@ namespace isaac parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false); } } - } /** @brief Executes a array_expression on the given models map*/ @@ -172,61 +172,70 @@ namespace isaac //Todo: technically the datatype should be per temporary numeric_type dtype = root_save.lhs.dtype; - detail::breakpoints_t breakpoints; - breakpoints.reserve(8); - - //Init - expression_type current_type; - if(root_save.lhs.array->nshape()==0) - current_type = SCALAR_AXPY_TYPE; - else if(root_save.lhs.array->nshape()==1) - current_type=VECTOR_AXPY_TYPE; + expression_type final_type; + //GEMM + if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){ + final_type = args.type; + } + //Default else - current_type=MATRIX_AXPY_TYPE; - expression_type final_type = current_type; - - /*----Parse required temporaries-----*/ - detail::parse(tree, rootidx, current_type, breakpoints, final_type); - std::vector > temporaries_; - - /*----Compute required temporaries----*/ - for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit) { - tools::shared_ptr const & pmodel = models[std::make_pair(rit->first, dtype)]; - array_expression::node const & node = tree[rit->second->node_index]; - array_expression::node const & lmost = lhs_most(tree, node); + detail::breakpoints_t breakpoints; + breakpoints.reserve(8); - //Creates temporary - tools::shared_ptr tmp; - switch(rit->first){ - case SCALAR_AXPY_TYPE: - case REDUCTION_TYPE: tmp = tools::shared_ptr(new array(1, dtype, context)); break; + //Init + expression_type current_type; + if(root_save.lhs.array->nshape()==0) + current_type = SCALAR_AXPY_TYPE; + else if(root_save.lhs.array->nshape()==1) + current_type=VECTOR_AXPY_TYPE; + else + current_type=MATRIX_AXPY_TYPE; + final_type = current_type; - case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], dtype, context)); break; - case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], dtype, context)); break; - case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[1], dtype, context)); break; + /*----Parse required temporaries-----*/ + detail::parse(tree, rootidx, current_type, breakpoints, final_type); + std::vector > temporaries_; - case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break; - case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break; - case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break; - case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break; - case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break; + /*----Compute required temporaries----*/ + for(detail::breakpoints_t::reverse_iterator rit = breakpoints.rbegin() ; rit != breakpoints.rend() ; ++rit) + { + tools::shared_ptr const & pmodel = models[std::make_pair(rit->first, dtype)]; + array_expression::node const & node = tree[rit->second->node_index]; + array_expression::node const & lmost = lhs_most(tree, node); - default: throw std::invalid_argument("Unrecognized operation"); - } - temporaries_.push_back(tmp); + //Creates temporary + tools::shared_ptr tmp; + switch(rit->first){ + case SCALAR_AXPY_TYPE: + case REDUCTION_TYPE: tmp = tools::shared_ptr(new array(1, dtype, context)); break; - tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE; - fill(tree[rootidx].lhs, (array&)*tmp); - tree[rootidx].rhs = *rit->second; - tree[rootidx].rhs.type_family = rit->second->type_family; + case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], dtype, context)); break; + case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], dtype, context)); break; + case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[1], dtype, context)); break; - //Execute - pmodel->execute(controller(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options())); - tree[rootidx] = root_save; + case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break; + case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break; + case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break; + case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break; + case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break; - //Incorporates the temporary within the array_expression - fill(*rit->second, (array&)*tmp); + default: throw std::invalid_argument("Unrecognized operation"); + } + temporaries_.push_back(tmp); + + tree[rootidx].op.type = OPERATOR_ASSIGN_TYPE; + fill(tree[rootidx].lhs, (array&)*tmp); + tree[rootidx].rhs = *rit->second; + tree[rootidx].rhs.type_family = rit->second->type_family; + + //Execute + pmodel->execute(controller(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options())); + tree[rootidx] = root_save; + + //Incorporates the temporary within the array_expression + fill(*rit->second, (array&)*tmp); + } } /*-----Compute final expression-----*/ diff --git a/lib/symbolic/expression.cpp b/lib/symbolic/expression.cpp index f02b3d220..cec9dc259 100644 --- a/lib/symbolic/expression.cpp +++ b/lib/symbolic/expression.cpp @@ -4,6 +4,7 @@ #include "isaac/value_scalar.h" #include #include "isaac/symbolic/expression.h" +#include "isaac/symbolic/preset.h" namespace isaac { diff --git a/lib/symbolic/preset.cpp b/lib/symbolic/preset.cpp new file mode 100644 index 000000000..3aa59574f --- /dev/null +++ b/lib/symbolic/preset.cpp @@ -0,0 +1,78 @@ +#include "isaac/symbolic/preset.h" + +namespace isaac +{ + +namespace symbolic +{ + +namespace preset +{ + +void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a) +{ + //Matrix-Matrix product node + if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY) + { + if(tree[rootidx].lhs.type_family==ARRAY_TYPE_FAMILY) a.A = &tree[rootidx].lhs; + if(tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) a.B = &tree[rootidx].rhs; + switch(tree[rootidx].op.type) + { + case OPERATOR_MATRIX_PRODUCT_NN_TYPE: a.type = MATRIX_PRODUCT_NN_TYPE; break; + case OPERATOR_MATRIX_PRODUCT_NT_TYPE: a.type = MATRIX_PRODUCT_NT_TYPE; break; + case OPERATOR_MATRIX_PRODUCT_TN_TYPE: a.type = MATRIX_PRODUCT_TN_TYPE; break; + case OPERATOR_MATRIX_PRODUCT_TT_TYPE: a.type = MATRIX_PRODUCT_TT_TYPE; break; + default: break; + } + } + + //Scalar multiplication node + if(tree[rootidx].op.type==OPERATOR_MULT_TYPE) + { + //alpha*PROD + if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY + && tree[tree[rootidx].rhs.node_index].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY) + { + a.alpha = &tree[rootidx].lhs; + handle_node(tree, tree[rootidx].rhs.node_index, a); + } + + //beta*C + if(tree[rootidx].lhs.type_family==VALUE_TYPE_FAMILY && tree[rootidx].rhs.type_family==ARRAY_TYPE_FAMILY) + { + a.beta = &tree[rootidx].lhs; + a.C = &tree[rootidx].rhs; + } + } +} + +gemm::args gemm::check(array_expression::container_type & tree, size_t rootidx) +{ + lhs_rhs_element * assigned = &tree[rootidx].lhs; + gemm::args result ; + if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY) + { + rootidx = tree[rootidx].rhs.node_index; + //Form X + Y + if(tree[rootidx].op.type==OPERATOR_ADD_TYPE || tree[rootidx].op.type==OPERATOR_SUB_TYPE) + { + if(tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY) + handle_node(tree, tree[rootidx].lhs.node_index, result); + if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY) + handle_node(tree, tree[rootidx].rhs.node_index, result); + } + else + handle_node(tree, rootidx, result); + } + if(result.C == NULL) + result.C = assigned; + else if(result.C->array != assigned->array) + result.C = NULL; + return result; +} + +} + +} + +}