Backend: Bugfix in GEMM bound-checking

This commit is contained in:
Philippe
2015-06-27 13:14:46 -04:00
parent 4cce9d3efd
commit 743a559f76
2 changed files with 8 additions and 5 deletions

View File

@@ -252,6 +252,9 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
stream.inc_tab();
stream << "in_bounds_B[n] = (gidy*" << p_.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + n*" << fetch_size << ") < N;" << std::endl;
stream.dec_tab();
// for(unsigned int n = 0 ; n < p_.nL/fetch_size ; n++)
// stream << n>0?",":"" << "(gidy*" << p_.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + " << n*fetch_size << ") < N";
}
}
@@ -340,7 +343,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
for(int_t k = 0; k < p_.mL; k += p_.local_fetch_1)
for(int_t m = 0; m < p_.kL; m += p_.local_fetch_0*p_.simd_width)
{
string in_bounds = "in_bounds_A[" + to_string(k/p_.local_fetch_1) + "]";
string in_bounds = "in_bounds_A[" + to_string(k/p_.local_fetch_1) + "] && (idxT + block_k < K)";
string to_load = "A[" + to_string(k) + "*Ald + " + to_string(m/p_.simd_width) + ASTRIDE1 + "]";
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lAstore + lAstart + " + to_string(m*lAld+k)) << ";" << std::endl;
}
@@ -350,7 +353,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
for(int_t k = 0; k < p_.kL; k += p_.local_fetch_1)
for(int_t n = 0; n < p_.nL; n += p_.local_fetch_0*p_.simd_width)
{
string in_bounds = "in_bounds_B[" + to_string(n/(p_.local_fetch_0*p_.simd_width)) + "]";
string in_bounds = "in_bounds_B[" + to_string(n/(p_.local_fetch_0*p_.simd_width)) + "] && (idyT + block_k < K)";
string to_load = "B[" + to_string(k) + "*Bld + " + to_string(n/p_.simd_width) + BSTRIDE1 + "]";
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lBstore + lBstart + " + to_string(k*lBld+n)) << ";" << std::endl;
}
@@ -358,7 +361,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
for(int_t k = 0; k < p_.nL; k += p_.local_fetch_1)
for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width)
{
string in_bounds = "in_bounds_B[" + to_string(k/p_.local_fetch_1) + "]";
string in_bounds = "in_bounds_B[" + to_string(k/p_.local_fetch_1) + "] && (idxT + block_k < K)";
string to_load = "B[" + to_string(k) + "*Bld + " + to_string(n/p_.simd_width) + BSTRIDE1 + "]";
stream << VSTORE(HANDLE_BOUNDS(in_bounds, to_load), "0", "lBstore + lBstart + " + to_string(n*lBld+k)) << ";" << std::endl;
}

View File

@@ -256,8 +256,8 @@ std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > ini
res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_rows(1, 8, 8, 4, 16, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 8, 8, 1, 1, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 8, 8, 1, 1, 1, 1, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 8, 8, 1, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
}