From d8526669f5a656e7869625c110521d637b5be850 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 27 Jun 2019 12:39:17 -0700 Subject: [PATCH] fixup --- examples/cpp/shift.cpp | 2 +- lib/dnn/shift.cpp | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index ed949c74c..0b523f826 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -99,7 +99,7 @@ int main() { // shift std::vector params = { - 8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4 + 8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4 }; std::ostringstream oss; shift.src(oss); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index d07809f1f..1fc3645ec 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -8,8 +8,7 @@ void shift::set_ld(const std::vector& shapes, std::vector& ld) { size_t size = shapes.size(); ld.resize(size); - ld[4] = 1; - ld[3] = shapes[4]*ld[4]; + ld[3] = 1; ld[2] = shapes[3]*ld[3]; ld[1] = shapes[2]*ld[2]; ld[0] = shapes[1]*ld[1]; @@ -42,6 +41,9 @@ shift::shift(int B, int NC, shapes_c_ = {NF, H, W, B}; // memory strides set_ld(shapes_a_, ld_a_); + // build LUTs + build_deltas(); + build_masks(); } void shift::build_deltas() { @@ -89,7 +91,7 @@ std::vector shift::c_shapes(){ } size_t shift::get_nflops() { - return 2 * M_ * N_ * K_; + return 2. * M_ * N_ * K_; }