GEMM: Incorporated K bounds checking inside kernel

This commit is contained in:
Philippe Tillet
2015-07-16 13:29:07 -04:00
parent 9de87da993
commit 1e3c853b58
4 changed files with 21 additions and 36 deletions

View File

@@ -176,12 +176,9 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if(has_depth)
{
stream << "size_t gidz = " << GroupIdx2(backend) << ";" << std::endl;
stream << "size_t chunk_size = K/" << p_.depth << ";" << std::endl;
stream << "size_t offz = chunk_size*gidz;" << std::endl;
}
else
{
stream << "size_t chunk_size = K;" << std::endl;
stream << "size_t div = (K+" << p_.depth-1 << ")/" << p_.depth << ";" << std::endl;
stream << "size_t offz = div*gidz;" << std::endl;
stream << "K = min(K - div*gidz, div);" << std::endl;
}
stream << std::endl;
@@ -190,8 +187,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "size_t idyT = idt / " << p_.local_fetch_0 << ";" << std::endl;
stream << std::endl;
if (A_trans_=='N')
stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl;
else
@@ -203,6 +198,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*ldb" << (has_depth?"+ offz":"") << ";" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npA << " ; ++i) Ai[i] = A;" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl;
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
@@ -210,10 +206,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
else
stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl;
stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl;
@@ -224,11 +216,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lldb << " + idxT*" << p_.simd_width << ";" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "for(size_t block_k=0; block_k < chunk_size ; block_k+=" << p_.kL << "){" << std::endl;
stream << "for(long block_k=K; block_k > 0 ; block_k-=" << p_.kL << "){" << std::endl;
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
@@ -238,8 +229,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
@@ -251,8 +241,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -266,8 +255,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
if(check_bounds_)
to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
@@ -279,8 +267,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
if(check_bounds_)
to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0";
to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0";
stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
@@ -457,9 +444,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options)
{
if(M==0 || N==0 || K==0)
return;
using tools::align;
if(M==0 || N==0 || K==0)
return;
char gemm_name[32] = {"gemm"};
char reduce_name[32] = {"reduce"};
@@ -478,8 +466,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
driver::Kernel gemm(program, gemm_name);
driver::NDRange local(p_.local_size_0, p_.local_size_1);
using tools::align;
driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.local_size_0), align(align(N,p_.nS)/p_.nS, p_.local_size_1), p_.depth);
unsigned int current_arg = 0;
@@ -611,17 +597,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
execution_options_type const & options = ctr.execution_options();
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
if (lK==0 || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1)
if (ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1)
{
fallback.enqueue_block(queue, M, N, K, *pA, *pB, *pC, alpha, beta, program, "fallback", options);
}
else
{
// std::cout << p_.local_size_0 << " " << p_.kL << " " << p_.local_size_1 << " " << p_.depth << std::endl;
value_scalar _1(1, dtype);
enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
// value_scalar _1(1, dtype);
enqueue_block(queue, M, N, K, *pA, *pB, *pC, alpha, beta, program, suffix, options);
// fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
}
}

View File

@@ -255,10 +255,10 @@ std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::
res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::ger(1,8,8,8,8,templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::gemv_n(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::gemv_t(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::gemm_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 32, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
}
return res;
}

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/wrap/clBLAS.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/kernel.cpp src/lib/driver/event.cpp src/lib/driver/command_queue.cpp src/lib/driver/program.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/axpy.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -98,7 +98,7 @@ template<typename T>
void test_impl(T epsilon, ad::driver::Context const & ctx)
{
int_t M = 173;
int_t N = 233;
int_t N = 256;
int_t K = 293;
int_t SUBM = 64;