GEMM: Safer bounds checking for K

This commit is contained in:
Philippe Tillet
2015-07-22 17:46:50 -07:00
parent 155554f5cf
commit 1cec0a9183
2 changed files with 377 additions and 345 deletions

View File

@@ -171,7 +171,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << _size_t << " idt;" << std::endl;
if(has_depth)
stream << _size_t << " gidz, div, offz;" << std::endl;
stream << "int Ky, Kx;" << std::endl;
stream << "A += offa;" << std::endl;
stream << "B += offb;" << std::endl;
@@ -200,7 +199,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "idT.x *= " << p_.simd_width << ";" << std::endl;
stream << "M -= ids.x;" << std::endl;
if(A_trans_=='N')
stream << "M -= idT.x;" << std::endl;
else
stream << "M -= idT.y;" << std::endl;
stream << "N -= ids.y;" << std::endl;
if(B_trans_=='T')
stream << "N -= idT.x;" << std::endl;
else
stream << "N -= idT.y;" << std::endl;
if (A_trans_=='N')
{
@@ -247,47 +255,30 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl;
stream << "if( " << i*p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << ";" << std::endl;
else
stream << "if(idT.y + " << i*p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl;
stream << "if(" << i*p_.local_fetch_1 << " < M) Ai[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*lda;" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "if(idT.x + " << i*p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl;
stream << "if(" << i*p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += (idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << ";" << std::endl;
else
stream << "if(idT.y + " << i*p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl;
stream << "if(" << i*p_.local_fetch_1 << " < N) Bi[" << i << "] += (idT.y + " << i*p_.local_fetch_1 << ")*ldb;" << std::endl;
stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "Ky = K - idT.y;" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "Kx = K - idT.x;" << std::endl;
stream << "//Outer loop" << std::endl;
stream << "while(K > 0){" << std::endl;
stream << "while(K >=" << p_.kL << "){" << std::endl;
stream.inc_tab();
auto fetch_to_lds = [&](bool last_iteration)
{
stream << LocalBarrier(backend) << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
{
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
stream << vint << " condy" << k << " = (" << vint << ")(" << k << ") < Ky;" << std::endl;
}
if(A_trans_=='T' || B_trans_=='N')
{
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
{
stream << vint << " condx" << k << " = (" << vint << ")(";
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << (s>0?",":"") << k + s;
stream << ") < Kx;" << std::endl;
}
}
stream << "//Fetch A to local memory" << std::endl;
if (A_trans_=='N')
{
@@ -296,9 +287,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{
std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]");
to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
if(last_iteration)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "storeA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "storeA + " + to_string(k*llda+m)) << ";" << std::endl;
}
}
else
@@ -308,9 +301,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{
std::string mm = to_string(m/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]");
to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
stream << VSTORE(to_load, "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
if(last_iteration)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "storeA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"), "0", "storeA + " + to_string(m*llda+k)) << ";" << std::endl;
}
}
@@ -322,9 +318,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{
std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0));
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]");
to_load = "(" + kk + " < Ky)?select((" + vdtype + ")0, " + to_load + ", condy" + kk + "):0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
if(last_iteration)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "storeB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "storeB + " + to_string(k*lldb+n)) << ";" << std::endl;
}
}
else
@@ -334,9 +332,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{
std::string nn = to_string(n/p_.local_fetch_1);
std::string kk = to_string(k);
string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]");
to_load = "(" + kk + " < Kx)?select((" + vdtype + ")0, " + to_load + ", condx" + kk + "):0";
stream << VSTORE(to_load, "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
if(last_iteration)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "storeB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
else
stream << VSTORE(VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "storeB + " + to_string(n*lldb+k)) << ";" << std::endl;
}
}
@@ -352,7 +353,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << LocalBarrier(backend) << ";" << std::endl;
stream << "//Inner loop" << std::endl;
stream << "for(unsigned int k = 0; k < " << p_.kL << "; k+=" << p_.kS << "){" << std::endl;
stream.inc_tab();
@@ -419,12 +419,9 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.dec_tab();
stream << "}" << std::endl;
stream << "K -= " << p_.kL << ";" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "Ky -= " << p_.kL << ";" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "Kx -= " << p_.kL << ";" << std::endl;
//Increment A pointers to global memory
if (A_trans_=='N')
@@ -442,11 +439,46 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(unsigned int i = 0 ; i < npB ; ++i)
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
};
fetch_to_lds(false);
stream.dec_tab();
stream << "}" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
stream << "int Ky = K - idT.y;" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
stream << "int Kx = K - idT.x;" << std::endl;
if(A_trans_=='N' || B_trans_=='T')
for(unsigned int k = 0; k < p_.kL; k += p_.local_fetch_1)
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
if(A_trans_=='T' || B_trans_=='N')
{
for(unsigned int k = 0 ; k < p_.kL ; k += p_.local_fetch_0*p_.simd_width)
for(unsigned int s = 0 ; s < p_.simd_width ; ++s)
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
}
fetch_to_lds(true);
stream << "//Write back C" << std::endl;
stream << "M += ids.x;" << std::endl;
if(A_trans_=='N')
stream << "M += idT.x;" << std::endl;
else
stream << "M += idT.y;" << std::endl;
if(B_trans_=='T')
stream << "N += idT.x;" << std::endl;
else
stream << "N += idT.y;" << std::endl;
stream << "N += ids.y;" << std::endl;
stream << _size_t << " offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl;
stream << _size_t << " offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl;

2
python/setup.py Executable file → Normal file
View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/execute.cpp src/lib/model/predictors/random_forest.cpp src/lib/model/model.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/dot.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/stream.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp src/lib/array.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
src = 'src/lib/array.cpp src/lib/value_scalar.cpp src/lib/wrap/clBLAS.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/context.cpp src/lib/driver/program.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/event.cpp src/lib/driver/device.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/backend/parse.cpp src/lib/backend/mapped_object.cpp src/lib/backend/templates/gemm.cpp src/lib/backend/templates/base.cpp src/lib/backend/templates/axpy.cpp src/lib/backend/templates/ger.cpp src/lib/backend/templates/gemv.cpp src/lib/backend/templates/dot.cpp src/lib/backend/stream.cpp src/lib/backend/keywords.cpp src/lib/backend/binder.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]