JIT: No longer use the fallback path when stride[0] > 1
The fallback mechanism was pretty messy.
This commit is contained in:
@@ -132,11 +132,12 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
#define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, true)
|
||||
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, false)
|
||||
#define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, "1", backend)
|
||||
#define ASTRIDE1 string(check_bounds_?"*Astride1":"")
|
||||
#define BSTRIDE1 string(check_bounds_?"*Bstride1":"")
|
||||
#define CSTRIDE1 string(check_bounds_?"*Cstride1":"")
|
||||
|
||||
|
||||
symbolic::preset::matrix_product::args args;
|
||||
infos(tree, args);
|
||||
std::string ASTRIDE1 = (args.A->ld[0] > 1)?"*Astride1":"";
|
||||
std::string BSTRIDE1 = (args.B->ld[0] > 1)?"*Bstride1":"";
|
||||
std::string CSTRIDE1 = (args.C->ld[0] > 1)?"*Cstride1":"";
|
||||
|
||||
//////////////////
|
||||
/// INIT
|
||||
@@ -681,7 +682,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return {M, N, K};
|
||||
}
|
||||
|
||||
matrix_product::matrix_product(matrix_product_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||
matrix_product::matrix_product(matrix_product_parameters const & parameters, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans)
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN;
|
||||
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN;
|
||||
@@ -696,14 +697,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return infos((expression_tree&)expressions, dummy);
|
||||
}
|
||||
|
||||
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, runtime::execution_handler const & control)
|
||||
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & control)
|
||||
{
|
||||
using namespace tools;
|
||||
|
||||
matrix_product & fallback = (matrix_product&)fallback_base;
|
||||
expression_tree const & expressions = control.x();
|
||||
|
||||
|
||||
symbolic::preset::matrix_product::args args;
|
||||
std::vector<int_t> MNK = infos(expressions, args);
|
||||
int_t M = MNK[0];
|
||||
@@ -714,10 +710,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
return;
|
||||
//Enqueue
|
||||
runtime::execution_options_type const & options = control.execution_options();
|
||||
if (args.A->ld[0] > 1 || args.B->ld[0] > 1 || args.C->ld[0] > 1)
|
||||
fallback.enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, "fallback", options);
|
||||
else
|
||||
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
|
||||
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -725,8 +718,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'N')
|
||||
{
|
||||
}
|
||||
|
||||
@@ -735,8 +728,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'N')
|
||||
{ }
|
||||
|
||||
//
|
||||
@@ -744,8 +737,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'T')
|
||||
{ }
|
||||
|
||||
//
|
||||
@@ -753,8 +746,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
||||
, int_t lfetch0, int_t lfetch1) :
|
||||
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'T')
|
||||
{ }
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user