JIT: No longer using fallbacks for stride[0] > 1

It was pretty messy.
This commit is contained in:
Philippe Tillet
2016-04-10 16:31:29 -04:00
parent 81139e0642
commit 1e439ad5bc
20 changed files with 5232 additions and 113 deletions

View File

@@ -132,11 +132,12 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
#define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, true)
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, "1", backend, false)
#define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, "1", backend)
#define ASTRIDE1 string(check_bounds_?"*Astride1":"")
#define BSTRIDE1 string(check_bounds_?"*Bstride1":"")
#define CSTRIDE1 string(check_bounds_?"*Cstride1":"")
symbolic::preset::matrix_product::args args;
infos(tree, args);
std::string ASTRIDE1 = (args.A->ld[0] > 1)?"*Astride1":"";
std::string BSTRIDE1 = (args.B->ld[0] > 1)?"*Bstride1":"";
std::string CSTRIDE1 = (args.C->ld[0] > 1)?"*Cstride1":"";
//////////////////
/// INIT
@@ -681,7 +682,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return {M, N, K};
}
matrix_product::matrix_product(matrix_product_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
matrix_product::matrix_product(matrix_product_parameters const & parameters, char A_trans, char B_trans) : base_impl<matrix_product, matrix_product_parameters>(parameters, FUSE_INDEPENDENT), A_trans_(A_trans), B_trans_(B_trans)
{
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN;
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN;
@@ -696,14 +697,9 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return infos((expression_tree&)expressions, dummy);
}
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback_base, runtime::execution_handler const & control)
void matrix_product::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & control)
{
using namespace tools;
matrix_product & fallback = (matrix_product&)fallback_base;
expression_tree const & expressions = control.x();
symbolic::preset::matrix_product::args args;
std::vector<int_t> MNK = infos(expressions, args);
int_t M = MNK[0];
@@ -714,10 +710,7 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
return;
//Enqueue
runtime::execution_options_type const & options = control.execution_options();
if (args.A->ld[0] > 1 || args.B->ld[0] > 1 || args.C->ld[0] > 1)
fallback.enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, "fallback", options);
else
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
enqueue_block(queue, M, N, K, *args.A, *args.B, *args.C, args.alpha, args.beta, program, suffix, options);
}
//
@@ -725,8 +718,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
, int_t lfetch0, int_t lfetch1) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'N')
{
}
@@ -735,8 +728,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
, int_t lfetch0, int_t lfetch1) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'N')
{ }
//
@@ -744,8 +737,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
, int_t lfetch0, int_t lfetch1) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'N', 'T')
{ }
//
@@ -753,8 +746,8 @@ matrix_product_parameters::matrix_product_parameters(unsigned int simd_width
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
, int_t lfetch0, int_t lfetch1) :
matrix_product(matrix_product_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), 'T', 'T')
{ }
}