Further fixes

This commit is contained in:
Philippe Tillet
2016-10-03 03:24:49 -04:00
parent fca79c317e
commit 1852ddef72
3 changed files with 11 additions and 11 deletions

View File

@@ -50,9 +50,11 @@ public:
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr); void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr);
private: private:
//Parameters //Parameters
unsigned int kL_;
unsigned int depth_;
unsigned int mL_;
unsigned int kL_;
unsigned int nL_;
unsigned int depth_;
unsigned int mS_; unsigned int mS_;
unsigned int kS_; unsigned int kS_;
unsigned int nS_; unsigned int nS_;
@@ -63,9 +65,6 @@ private:
unsigned int lf0_; unsigned int lf0_;
unsigned int lf1_; unsigned int lf1_;
unsigned int mL_;
unsigned int nL_;
bool prefetch_; bool prefetch_;
bool unroll_outer_; bool unroll_outer_;
// //

View File

@@ -120,7 +120,8 @@ namespace templates
bool has_depth = depth_ > 1; bool has_depth = depth_ > 1;
#define VLOAD(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, true) #define VLOAD(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, true)
#define VLOAD_MISALIGNED(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, false) #define VLOAD_MISALIGNED(offset, ptr) vload(vwidth_, sdtype, offset, ptr, "1", backend, false)
#define VSTORE(value, offset, ptr) vstore(vwidth_, sdtype, value, offset, ptr, "1", backend) #define VSTORE_LDSA(value, offset, ptr) vstore(vwidth_, sdtype, value, offset, ptr, "1", backend, llda%vwidth_==0)
#define VSTORE_LDSB(value, offset, ptr) vstore(vwidth_, sdtype, value, offset, ptr, "1", backend, lldb%vwidth_==0)
symbolic::preset::gemm::args args; symbolic::preset::gemm::args args;
infos(tree, args); infos(tree, args);
@@ -315,7 +316,7 @@ namespace templates
for(unsigned int s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl; stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl; stream << VSTORE_LDSA(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
} }
} }
else else
@@ -330,7 +331,7 @@ namespace templates
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl; stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"), "0", "ldsA + " + to_string(m*llda+k)) << ";" << std::endl; stream << VSTORE_LDSA(VLOAD_MISALIGNED("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"), "0", "ldsA + " + to_string(m*llda+k)) << ";" << std::endl;
} }
} }
@@ -346,7 +347,7 @@ namespace templates
for(unsigned int s = 0 ; s < vwidth_ ; ++s) for(unsigned int s = 0 ; s < vwidth_ ; ++s)
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl; stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl; stream << VSTORE_LDSB(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
} }
} }
else else
@@ -361,7 +362,7 @@ namespace templates
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl; stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
else else
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl; stream << VSTORE_LDSB(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
} }
} }

View File

@@ -73,7 +73,7 @@ def main():
libraries += ['gnustl_shared'] libraries += ['gnustl_shared']
#Source files #Source files
src = 'src/lib/random/rand.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/object.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/generation/base.cpp src/lib/runtime/execute.cpp src/lib/runtime/database.cpp src/lib/runtime/profiles.cpp src/lib/runtime/predictors/random_forest.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/driver/backend.cpp src/lib/driver/device.cpp src/lib/driver/kernel.cpp src/lib/driver/buffer.cpp src/lib/driver/platform.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/command_queue.cpp src/lib/driver/dispatch.cpp src/lib/driver/program_cache.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/api/blas/clBLAS.cpp src/lib/api/blas/cublas.cpp src/lib/exception/api.cpp src/lib/exception/driver.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']] src = 'src/lib/runtime/predictors/random_forest.cpp src/lib/runtime/profiles.cpp src/lib/runtime/database.cpp src/lib/runtime/execute.cpp src/lib/exception/driver.cpp src/lib/exception/api.cpp src/lib/random/rand.cpp src/lib/jit/generation/elementwise_1d.cpp src/lib/jit/generation/reduce_2d.cpp src/lib/jit/generation/reduce_1d.cpp src/lib/jit/generation/base.cpp src/lib/jit/generation/gemm.cpp src/lib/jit/generation/engine/keywords.cpp src/lib/jit/generation/engine/stream.cpp src/lib/jit/generation/elementwise_2d.cpp src/lib/jit/syntax/expression/expression.cpp src/lib/jit/syntax/expression/preset.cpp src/lib/jit/syntax/expression/operations.cpp src/lib/jit/syntax/engine/binder.cpp src/lib/jit/syntax/engine/macro.cpp src/lib/jit/syntax/engine/process.cpp src/lib/jit/syntax/engine/object.cpp src/lib/value_scalar.cpp src/lib/array.cpp src/lib/api/blas/cublas.cpp src/lib/api/blas/clBLAS.cpp src/lib/driver/dispatch.cpp src/lib/driver/kernel.cpp src/lib/driver/backend.cpp src/lib/driver/platform.cpp src/lib/driver/buffer.cpp src/lib/driver/event.cpp src/lib/driver/ndrange.cpp src/lib/driver/device.cpp src/lib/driver/program_cache.cpp src/lib/driver/check.cpp src/lib/driver/command_queue.cpp src/lib/driver/handle.cpp src/lib/driver/context.cpp src/lib/driver/program.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/' boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']: for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x] src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]