GEMM: Incorporated K bounds checking inside kernel

This commit is contained in:
Philippe Tillet
2015-07-16 13:29:07 -04:00
parent 9de87da993
commit 1e3c853b58
4 changed files with 21 additions and 36 deletions

View File

@@ -176,12 +176,9 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if(has_depth)
{
stream << "size_t gidz = " << GroupIdx2(backend) << ";" << std::endl;
stream << "size_t chunk_size = K/" << p_.depth << ";" << std::endl;
stream << "size_t offz = chunk_size*gidz;" << std::endl;
}
else
{
stream << "size_t chunk_size = K;" << std::endl;
stream << "size_t div = (K+" << p_.depth-1 << ")/" << p_.depth << ";" << std::endl;
stream << "size_t offz = div*gidz;" << std::endl;
stream << "K = min(K - div*gidz, div);" << std::endl;
}
stream << std::endl;
@@ -190,8 +187,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "size_t idyT = idt / " << p_.local_fetch_0 << ";" << std::endl;
stream << std::endl;
if (A_trans_=='N')
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
else
@@ -203,6 +198,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npA << " ; ++i) Ai[i] = A;" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl;
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
@@ -210,10 +206,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
else
stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
@@ -224,11 +216,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lldb << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "for(size_t block_k=0; block_k < chunk_size ; block_k+=" << p_.kL << "){" << std::endl;
stream << "for(long block_k=K; block_k > 0 ; block_k-=" << p_.kL << "){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
@@ -238,8 +229,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
@@ -251,8 +241,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -266,8 +255,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
@@ -279,8 +267,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
@@ -457,9 +444,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options)
{
if(M==0 || N==0 || K==0)
return;
using tools::align;
if(M==0 || N==0 || K==0)
return;
char gemm_name[32] = {"gemm"};
char reduce_name[32] = {"reduce"};
@@ -478,8 +466,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
driver::Kernel gemm(program, gemm_name);
driver::NDRange local(p_.local_size_0, p_.local_size_1);
using tools::align;
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth);
unsigned int current_arg = 0;
@@ -611,17 +597,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
execution_options_type const & options = ctr.execution_options();
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
if (lK==0 || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1)
if (ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1)
{
fallback.enqueue_block(queue, M, N, K, *pA, *pB, *pC, alpha, beta, program, "fallback", options);
}
else
{
// std::cout << p_.local_size_0 << " " << p_.kL << " " << p_.local_size_1 << " " << p_.depth << std::endl;
value_scalar _1(1, dtype);
enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
// value_scalar _1(1, dtype);
enqueue_block(queue, M, N, K, *pA, *pB, *pC, alpha, beta, program, suffix, options);
// fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
}
}