|
|
|
@@ -132,11 +132,12 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
#define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, true)
|
|
|
|
|
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, false)
|
|
|
|
|
#define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, "1", backend)
|
|
|
|
|
#define ASTRIDE1 string(check_bounds_?"*Astride1":"")
|
|
|
|
|
#define BSTRIDE1 string(check_bounds_?"*Bstride1":"")
|
|
|
|
|
#define CSTRIDE1 string(check_bounds_?"*Cstride1":"")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
symbolic::preset::matrix_product::args args;
|
|
|
|
|
infos(tree, args);
|
|
|
|
|
std::string ASTRIDE1 = (args.A->ld[0] > 1)?"*Astride1":"";
|
|
|
|
|
std::string BSTRIDE1 = (args.B->ld[0] > 1)?"*Bstride1":"";
|
|
|
|
|
std::string CSTRIDE1 = (args.C->ld[0] > 1)?"*Cstride1":"";
|
|
|
|
|
|
|
|
|
|
//////////////////
|
|
|
|
|
/// INIT
|
|
|
|
@@ -681,7 +682,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
return {M, N, K};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
matrix_product::matrix_product(matrix_product_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
|
|
|
|
matrix_product::matrix_product(matrix_product_parameters const & parameters, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans)
|
|
|
|
|
{
|
|
|
|
|
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN;
|
|
|
|
|
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN;
|
|
|
|
@@ -696,14 +697,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
return infos((expression_tree&)expressions, dummy);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, runtime::execution_handler const & control)
|
|
|
|
|
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & control)
|
|
|
|
|
{
|
|
|
|
|
using namespace tools;
|
|
|
|
|
|
|
|
|
|
matrix_product & fallback = (matrix_product&)fallback_base;
|
|
|
|
|
expression_tree const & expressions = control.x();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
symbolic::preset::matrix_product::args args;
|
|
|
|
|
std::vector<int_t> MNK = infos(expressions, args);
|
|
|
|
|
int_t M = MNK[0];
|
|
|
|
@@ -714,10 +710,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
return;
|
|
|
|
|
//Enqueue
|
|
|
|
|
runtime::execution_options_type const & options = control.execution_options();
|
|
|
|
|
if (args.A->ld[0] > 1 || args.B->ld[0] > 1 || args.C->ld[0] > 1)
|
|
|
|
|
fallback.enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, "fallback", options);
|
|
|
|
|
else
|
|
|
|
|
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
|
|
|
|
|
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -725,8 +718,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
|
|
|
|
, int_t lfetch0, int_t lfetch1) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'N')
|
|
|
|
|
{
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -735,8 +728,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
|
|
|
|
, int_t lfetch0, int_t lfetch1) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'N')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -744,8 +737,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
|
|
|
|
, int_t lfetch0, int_t lfetch1) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'T')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@@ -753,8 +746,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
|
|
|
|
, int_t ls0, int_t KL, int_t ls1, int_t D
|
|
|
|
|
, int_t ms, int_t ks, int_t ns
|
|
|
|
|
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
|
|
|
|
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
|
|
|
|
, int_t lfetch0, int_t lfetch1) :
|
|
|
|
|
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'T')
|
|
|
|
|
{ }
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|