Backend: Now not creating a temporary upon C = alpha*dot(op(A), op(B)) + beta*C

This commit is contained in:
Philippe Tillet
2015-06-27 17:55:01 -07:00
parent 3525edd54c
commit 0e207e7ca4
7 changed files with 206 additions and 67 deletions

View File

@@ -1,10 +1,12 @@
#include "isaac/array.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/keywords.h"
#include "isaac/model/model.h"
#include "isaac/symbolic/preset.h"
#include "isaac/tools/make_vector.hpp"
#include "isaac/tools/to_string.hpp"
#include "isaac/tools/miscellaneous.hpp"
#include "isaac/model/model.h"
namespace isaac
{
@@ -640,17 +642,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return array(M, s0, s1);
}
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, lhs_rhs_element *& C, lhs_rhs_element *& A, lhs_rhs_element *& B, lhs_rhs_element *&, lhs_rhs_element *&)
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments)
{
isaac::array_expression & array_expression = (*expressions.data().front());
array_expression::container_type & array = array_expression.tree();
std::size_t root = array_expression.root();
C = &array[root].lhs;
A = &array[array[root].rhs.node_index].lhs;
B = &array[array[root].rhs.node_index].rhs;
int_t M = C->array->shape()[0];
int_t N = C->array->shape()[1];
int_t K = (A_trans_=='T')?A->array->shape()[0]:A->array->shape()[1];
arguments = symbolic::preset::gemm::check(array, root);
int_t M = arguments.C->array->shape()[0];
int_t N = arguments.C->array->shape()[1];
int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1];
return {M, N, K};
}
@@ -665,8 +665,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions)
{
lhs_rhs_element *d0, *d1, *d2, *d3, *d4;
return infos(expressions, d0, d1, d2, d3, d4);
symbolic::preset::gemm::args dummy;
return infos(expressions, dummy);
}
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
@@ -677,8 +677,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
expressions_tuple const & expressions = ctr.x();
lhs_rhs_element * eC, *eA, *eB, *ealpha, *ebeta;
std::vector<int_t> MNK = infos(expressions, eC, eA, eB, ealpha, ebeta);
symbolic::preset::gemm::args args;
std::vector<int_t> MNK = infos(expressions, args);
int_t M = MNK[0];
int_t N = MNK[1];
@@ -688,9 +688,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
if(M==0 || N == 0 || K ==0)
return;
//Extract
array * pA = eA->array;
array * pB = eB->array;
array * pC = eC->array;
array * pA = args.A->array;
array * pB = args.B->array;
array * pC = args.C->array;
//Check if requires fallback
int_t ldstrideA = pA->stride()[0];
@@ -699,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t ldstartA = pA->start()[0];
int_t ldstartB = pB->start()[0];
numeric_type dtype = eC->dtype;
numeric_type dtype = args.C->dtype;
//Enqueue
bool swap_A = (A_trans_=='T');
@@ -708,7 +708,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar beta(0, dtype);
value_scalar alpha(1, dtype);
value_scalar _1(1, dtype);
execution_options_type const & options = ctr.execution_options();
@@ -724,6 +723,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t lM = M / p_.mL * p_.mL;
int_t lN = N / p_.nL * p_.nL;
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
value_scalar _1(1, dtype);
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);