Backend: No longer creates a temporary when evaluating C = alpha*dot(op(A), op(B)) + beta*C

This commit is contained in:
Philippe Tillet
2015-06-27 17:55:01 -07:00
parent 3525edd54c
commit 0e207e7ca4
7 changed files with 206 additions and 67 deletions

View File

@@ -2,6 +2,8 @@
#define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H #define ISAAC_BACKEND_TEMPLATES_MPRODUCT_H
#include "isaac/backend/templates/base.h" #include "isaac/backend/templates/base.h"
#include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac namespace isaac
{ {
@@ -46,7 +48,7 @@ private:
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C, void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options); value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap); array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
std::vector<int_t> infos(expressions_tuple const & expressions, lhs_rhs_element *&C, lhs_rhs_element *&A, lhs_rhs_element *&B, lhs_rhs_element *&alpha, lhs_rhs_element *&beta); std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments);
public: public:
mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans); mproduct(mproduct::parameters_type const & parameters, bool check_bound, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expressions_tuple const & expressions); std::vector<int_t> input_sizes(expressions_tuple const & expressions);

View File

@@ -0,0 +1,49 @@
#ifndef ISAAC_SYMBOLIC_PRESET_H_
#define ISAAC_SYMBOLIC_PRESET_H_
#include "isaac/symbolic/expression.h"
namespace isaac
{
namespace symbolic
{
namespace preset
{
class gemm
{
public:
struct args
{
args(): alpha(NULL), A(NULL), B(NULL), beta(NULL), C(NULL), type(INVALID_EXPRESSION_TYPE){ }
lhs_rhs_element* alpha;
lhs_rhs_element* A;
lhs_rhs_element* B;
lhs_rhs_element* beta;
lhs_rhs_element* C;
expression_type type;
operator bool() const
{
return type!=INVALID_EXPRESSION_TYPE && C!=NULL;
}
};
private:
static void handle_node(array_expression::container_type &tree, size_t rootidx, args & a);
public:
static args check(array_expression::container_type & tree, size_t rootidx);
};
}
}
}
#endif

View File

@@ -127,6 +127,7 @@ template<> struct to_numeric_type<double> { static const numeric_type value = DO
enum expression_type enum expression_type
{ {
INVALID_EXPRESSION_TYPE,
SCALAR_AXPY_TYPE, SCALAR_AXPY_TYPE,
VECTOR_AXPY_TYPE, VECTOR_AXPY_TYPE,
MATRIX_AXPY_TYPE, MATRIX_AXPY_TYPE,
@@ -136,8 +137,7 @@ enum expression_type
MATRIX_PRODUCT_NN_TYPE, MATRIX_PRODUCT_NN_TYPE,
MATRIX_PRODUCT_TN_TYPE, MATRIX_PRODUCT_TN_TYPE,
MATRIX_PRODUCT_NT_TYPE, MATRIX_PRODUCT_NT_TYPE,
MATRIX_PRODUCT_TT_TYPE, MATRIX_PRODUCT_TT_TYPE
INVALID_EXPRESSION_TYPE
}; };
struct slice struct slice

View File

@@ -1,10 +1,12 @@
#include "isaac/array.h" #include "isaac/array.h"
#include "isaac/backend/templates/mproduct.h" #include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/keywords.h" #include "isaac/backend/keywords.h"
#include "isaac/model/model.h"
#include "isaac/symbolic/preset.h"
#include "isaac/tools/make_vector.hpp" #include "isaac/tools/make_vector.hpp"
#include "isaac/tools/to_string.hpp" #include "isaac/tools/to_string.hpp"
#include "isaac/tools/miscellaneous.hpp" #include "isaac/tools/miscellaneous.hpp"
#include "isaac/model/model.h"
namespace isaac namespace isaac
{ {
@@ -640,17 +642,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return array(M, s0, s1); return array(M, s0, s1);
} }
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, lhs_rhs_element *& C, lhs_rhs_element *& A, lhs_rhs_element *& B, lhs_rhs_element *&, lhs_rhs_element *&) std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments)
{ {
isaac::array_expression & array_expression = (*expressions.data().front()); isaac::array_expression & array_expression = (*expressions.data().front());
array_expression::container_type & array = array_expression.tree(); array_expression::container_type & array = array_expression.tree();
std::size_t root = array_expression.root(); std::size_t root = array_expression.root();
C = &array[root].lhs; arguments = symbolic::preset::gemm::check(array, root);
A = &array[array[root].rhs.node_index].lhs; int_t M = arguments.C->array->shape()[0];
B = &array[array[root].rhs.node_index].rhs; int_t N = arguments.C->array->shape()[1];
int_t M = C->array->shape()[0]; int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1];
int_t N = C->array->shape()[1];
int_t K = (A_trans_=='T')?A->array->shape()[0]:A->array->shape()[1];
return {M, N, K}; return {M, N, K};
} }
@@ -665,8 +665,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions)
{ {
lhs_rhs_element *d0, *d1, *d2, *d3, *d4; symbolic::preset::gemm::args dummy;
return infos(expressions, d0, d1, d2, d3, d4); return infos(expressions, dummy);
} }
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr) void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
@@ -677,8 +677,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
expressions_tuple const & expressions = ctr.x(); expressions_tuple const & expressions = ctr.x();
lhs_rhs_element * eC, *eA, *eB, *ealpha, *ebeta; symbolic::preset::gemm::args args;
std::vector<int_t> MNK = infos(expressions, eC, eA, eB, ealpha, ebeta); std::vector<int_t> MNK = infos(expressions, args);
int_t M = MNK[0]; int_t M = MNK[0];
int_t N = MNK[1]; int_t N = MNK[1];
@@ -688,9 +688,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
if(M==0 || N == 0 || K ==0) if(M==0 || N == 0 || K ==0)
return; return;
//Extract //Extract
array * pA = eA->array; array * pA = args.A->array;
array * pB = eB->array; array * pB = args.B->array;
array * pC = eC->array; array * pC = args.C->array;
//Check if requires fallback //Check if requires fallback
int_t ldstrideA = pA->stride()[0]; int_t ldstrideA = pA->stride()[0];
@@ -699,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t ldstartA = pA->start()[0]; int_t ldstartA = pA->start()[0];
int_t ldstartB = pB->start()[0]; int_t ldstartB = pB->start()[0];
numeric_type dtype = eC->dtype; numeric_type dtype = args.C->dtype;
//Enqueue //Enqueue
bool swap_A = (A_trans_=='T'); bool swap_A = (A_trans_=='T');
@@ -708,7 +708,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar beta(0, dtype); value_scalar beta(0, dtype);
value_scalar alpha(1, dtype); value_scalar alpha(1, dtype);
value_scalar _1(1, dtype);
execution_options_type const & options = ctr.execution_options(); execution_options_type const & options = ctr.execution_options();
@@ -724,6 +723,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t lM = M / p_.mL * p_.mL; int_t lM = M / p_.mL * p_.mL;
int_t lN = N / p_.nL * p_.nL; int_t lN = N / p_.nL * p_.nL;
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth; int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
value_scalar _1(1, dtype);
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options); enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options); fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);

View File

@@ -7,6 +7,7 @@
#include <CL/cl.hpp> #include <CL/cl.hpp>
#include "isaac/model/model.h" #include "isaac/model/model.h"
#include "isaac/symbolic/expression.h" #include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac namespace isaac
{ {
@@ -157,7 +158,6 @@ namespace isaac
parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false); parse(array, node.rhs.node_index, next_type, breakpoints, final_type, false);
} }
} }
} }
/** @brief Executes a array_expression on the given models map*/ /** @brief Executes a array_expression on the given models map*/
@@ -172,6 +172,14 @@ namespace isaac
//Todo: technically the datatype should be per temporary //Todo: technically the datatype should be per temporary
numeric_type dtype = root_save.lhs.dtype; numeric_type dtype = root_save.lhs.dtype;
expression_type final_type;
//GEMM
if(symbolic::preset::gemm::args args = symbolic::preset::gemm::check(tree, rootidx)){
final_type = args.type;
}
//Default
else
{
detail::breakpoints_t breakpoints; detail::breakpoints_t breakpoints;
breakpoints.reserve(8); breakpoints.reserve(8);
@@ -183,7 +191,7 @@ namespace isaac
current_type=VECTOR_AXPY_TYPE; current_type=VECTOR_AXPY_TYPE;
else else
current_type=MATRIX_AXPY_TYPE; current_type=MATRIX_AXPY_TYPE;
expression_type final_type = current_type; final_type = current_type;
/*----Parse required temporaries-----*/ /*----Parse required temporaries-----*/
detail::parse(tree, rootidx, current_type, breakpoints, final_type); detail::parse(tree, rootidx, current_type, breakpoints, final_type);
@@ -228,6 +236,7 @@ namespace isaac
//Incorporates the temporary within the array_expression //Incorporates the temporary within the array_expression
fill(*rit->second, (array&)*tmp); fill(*rit->second, (array&)*tmp);
} }
}
/*-----Compute final expression-----*/ /*-----Compute final expression-----*/
models[std::make_pair(final_type, dtype)]->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options())); models[std::make_pair(final_type, dtype)]->execute(controller<expressions_tuple>(expression, c.execution_options(), c.dispatcher_options(), c.compilation_options()));

View File

@@ -4,6 +4,7 @@
#include "isaac/value_scalar.h" #include "isaac/value_scalar.h"
#include <CL/cl.hpp> #include <CL/cl.hpp>
#include "isaac/symbolic/expression.h" #include "isaac/symbolic/expression.h"
#include "isaac/symbolic/preset.h"
namespace isaac namespace isaac
{ {

78
lib/symbolic/preset.cpp Normal file
View File

@@ -0,0 +1,78 @@
#include "isaac/symbolic/preset.h"
namespace isaac
{
namespace symbolic
{
namespace preset
{
/**
 * @brief Examines a single tree node and stores in @p a the GEMM operands it carries.
 *
 * Two node shapes are recognized:
 *  - a matrix-product node: records A and/or B (only for operands that are
 *    plain arrays) and the matching MATRIX_PRODUCT_*_TYPE;
 *  - a multiplication node whose left operand is a value: either
 *    "alpha * product" (records alpha, then recurses into the product) or
 *    "beta * array" (records beta and C).
 * Any other node leaves @p a untouched.
 *
 * @param tree     container holding the expression nodes
 * @param rootidx  index of the node to examine
 * @param a        record updated in place
 */
void gemm::handle_node(array_expression::container_type &tree, size_t rootidx, args & a)
{
    lhs_rhs_element & left = tree[rootidx].lhs;
    lhs_rhs_element & right = tree[rootidx].rhs;

    //Matrix-Matrix product node
    if(tree[rootidx].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
    {
        if(left.type_family==ARRAY_TYPE_FAMILY)
            a.A = &left;
        if(right.type_family==ARRAY_TYPE_FAMILY)
            a.B = &right;

        //Translate the operator layout into the expression type
        switch(tree[rootidx].op.type)
        {
            case OPERATOR_MATRIX_PRODUCT_NN_TYPE:
                a.type = MATRIX_PRODUCT_NN_TYPE;
                break;
            case OPERATOR_MATRIX_PRODUCT_NT_TYPE:
                a.type = MATRIX_PRODUCT_NT_TYPE;
                break;
            case OPERATOR_MATRIX_PRODUCT_TN_TYPE:
                a.type = MATRIX_PRODUCT_TN_TYPE;
                break;
            case OPERATOR_MATRIX_PRODUCT_TT_TYPE:
                a.type = MATRIX_PRODUCT_TT_TYPE;
                break;
            default:
                break;
        }
    }

    //Scalar multiplication node
    if(tree[rootidx].op.type==OPERATOR_MULT_TYPE)
    {
        //Both recognized shapes need the scalar on the left-hand side
        const bool scalar_on_left = (left.type_family==VALUE_TYPE_FAMILY);

        //alpha*PROD
        if(scalar_on_left && right.type_family==COMPOSITE_OPERATOR_FAMILY)
        {
            const size_t child = right.node_index;
            if(tree[child].op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
            {
                a.alpha = &left;
                handle_node(tree, child, a);
            }
        }

        //beta*C
        if(scalar_on_left && right.type_family==ARRAY_TYPE_FAMILY)
        {
            a.beta = &left;
            a.C = &right;
        }
    }
}
/**
 * @brief Matches the assignment rooted at @p rootidx against
 *        C = [alpha*]dot(op(A), op(B)) [+ beta*C].
 *
 * The left-hand side of the root node is taken as the assigned array. Its
 * right-hand side is then inspected:
 *  - a sum X + Y is accepted only when both X and Y are sub-expressions and,
 *    together, they yield a product term and a "beta*C" term;
 *  - any other sub-expression is handed to handle_node() as is.
 *
 * @param tree     container holding the expression nodes
 * @param rootidx  index of the root (assignment) node
 * @return the collected operands. When no "beta*C" term was seen, C is the
 *         assigned array; when one was seen but refers to a different array
 *         than the assigned one, C is NULL.
 */
gemm::args gemm::check(array_expression::container_type & tree, size_t rootidx)
{
    lhs_rhs_element * assigned = &tree[rootidx].lhs;
    gemm::args result;
    if(tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY)
    {
        rootidx = tree[rootidx].rhs.node_index;
        //Form X + Y
        //Subtraction is deliberately not matched here: args cannot carry the
        //sign of a term, so X - Y would be indistinguishable from X + Y.
        if(tree[rootidx].op.type==OPERATOR_ADD_TYPE)
        {
            const bool both_composite = tree[rootidx].lhs.type_family==COMPOSITE_OPERATOR_FAMILY
                                        && tree[rootidx].rhs.type_family==COMPOSITE_OPERATOR_FAMILY;
            if(both_composite)
            {
                handle_node(tree, tree[rootidx].lhs.node_index, result);
                handle_node(tree, tree[rootidx].rhs.node_index, result);
            }
            //A sum is a GEMM only if one term is the product and the other is
            //beta*C. Otherwise a term of the sum would be silently dropped
            //(e.g. D + dot(A,B) or dot(A,B) + dot(D,E)), so report no match.
            if(result.type==INVALID_EXPRESSION_TYPE || result.C==NULL)
                result = gemm::args();
        }
        else
            handle_node(tree, rootidx, result);
    }
    //No explicit beta*C term: the output is simply the assigned array
    if(result.C == NULL)
        result.C = assigned;
    //The accumulated array must be the one being assigned to
    else if(result.C->array != assigned->array)
        result.C = NULL;
    return result;
}
}
}
}