Backend: Now not creating a temporary upon C = alpha*dot(op(A), op(B)) + beta*C
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/backend/templates/mproduct.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/model/model.h"
|
||||
#include "isaac/symbolic/preset.h"
|
||||
#include "isaac/tools/make_vector.hpp"
|
||||
#include "isaac/tools/to_string.hpp"
|
||||
#include "isaac/tools/miscellaneous.hpp"
|
||||
#include "isaac/model/model.h"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
|
||||
@@ -640,17 +642,15 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return array(M, s0, s1);
|
||||
}
|
||||
|
||||
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, lhs_rhs_element *& C, lhs_rhs_element *& A, lhs_rhs_element *& B, lhs_rhs_element *&, lhs_rhs_element *&)
|
||||
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments)
|
||||
{
|
||||
isaac::array_expression & array_expression = (*expressions.data().front());
|
||||
array_expression::container_type & array = array_expression.tree();
|
||||
std::size_t root = array_expression.root();
|
||||
C = &array[root].lhs;
|
||||
A = &array[array[root].rhs.node_index].lhs;
|
||||
B = &array[array[root].rhs.node_index].rhs;
|
||||
int_t M = C->array->shape()[0];
|
||||
int_t N = C->array->shape()[1];
|
||||
int_t K = (A_trans_=='T')?A->array->shape()[0]:A->array->shape()[1];
|
||||
arguments = symbolic::preset::gemm::check(array, root);
|
||||
int_t M = arguments.C->array->shape()[0];
|
||||
int_t N = arguments.C->array->shape()[1];
|
||||
int_t K = (A_trans_=='T')?arguments.A->array->shape()[0]:arguments.A->array->shape()[1];
|
||||
return {M, N, K};
|
||||
}
|
||||
|
||||
@@ -665,8 +665,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
|
||||
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions)
|
||||
{
|
||||
lhs_rhs_element *d0, *d1, *d2, *d3, *d4;
|
||||
return infos(expressions, d0, d1, d2, d3, d4);
|
||||
symbolic::preset::gemm::args dummy;
|
||||
return infos(expressions, dummy);
|
||||
}
|
||||
|
||||
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||
@@ -677,8 +677,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
expressions_tuple const & expressions = ctr.x();
|
||||
|
||||
|
||||
lhs_rhs_element * eC, *eA, *eB, *ealpha, *ebeta;
|
||||
std::vector<int_t> MNK = infos(expressions, eC, eA, eB, ealpha, ebeta);
|
||||
symbolic::preset::gemm::args args;
|
||||
std::vector<int_t> MNK = infos(expressions, args);
|
||||
|
||||
int_t M = MNK[0];
|
||||
int_t N = MNK[1];
|
||||
@@ -688,9 +688,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
if(M==0 || N == 0 || K ==0)
|
||||
return;
|
||||
//Extract
|
||||
array * pA = eA->array;
|
||||
array * pB = eB->array;
|
||||
array * pC = eC->array;
|
||||
array * pA = args.A->array;
|
||||
array * pB = args.B->array;
|
||||
array * pC = args.C->array;
|
||||
|
||||
//Check if requires fallback
|
||||
int_t ldstrideA = pA->stride()[0];
|
||||
@@ -699,7 +699,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
int_t ldstartA = pA->start()[0];
|
||||
int_t ldstartB = pB->start()[0];
|
||||
|
||||
numeric_type dtype = eC->dtype;
|
||||
numeric_type dtype = args.C->dtype;
|
||||
|
||||
//Enqueue
|
||||
bool swap_A = (A_trans_=='T');
|
||||
@@ -708,7 +708,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
|
||||
value_scalar beta(0, dtype);
|
||||
value_scalar alpha(1, dtype);
|
||||
value_scalar _1(1, dtype);
|
||||
|
||||
|
||||
execution_options_type const & options = ctr.execution_options();
|
||||
@@ -724,6 +723,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
int_t lM = M / p_.mL * p_.mL;
|
||||
int_t lN = N / p_.nL * p_.nL;
|
||||
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
|
||||
value_scalar _1(1, dtype);
|
||||
|
||||
enqueue_block(queue, lM, lN, lK, create_slice(*pA, 0, lM, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, beta, program, suffix, options);
|
||||
fallback.enqueue_block(queue, lM, lN, K - lK, create_slice(*pA, 0, lM, lK, K, swap_A), create_slice(*pB, lK, K, 0, lN, swap_B), create_slice(*pC, 0, lM, 0, lN, false), alpha, _1, program, "fallback", options);
|
||||
|
Reference in New Issue
Block a user