From 21fd0fd65e12aa2f9dcb6fa749099e205865b0af Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 28 Jun 2019 11:13:36 -0700 Subject: [PATCH] fixup --- examples/cpp/shift.cpp | 6 +++--- include/triton/dnn/shift.h | 5 +++-- lib/dnn/shift.cpp | 3 +++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index e23ba5c9c..fa7714782 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -68,12 +68,12 @@ int main() { // shift std::vector params = { - 8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4 + 4, 2, 32, 8, 2, 32, 8, 4, 2, 2, 8, 8, 4 }; std::ostringstream oss; shift.src(oss); std::string src = oss.str(); - jit.autotune("shift", src.c_str(), benchmark); +// jit.autotune("shift", src.c_str(), benchmark); jit.add_module("shift", src.c_str(), params); triton::driver::kernel* kernel = jit.get_function("shift"); triton::jit::launch_information info = jit.get_launch_info("shift"); @@ -81,7 +81,7 @@ int main() { stream->read(dc, true, 0, hc); shift.cpu_ref(rc.data(), ha.data(), hb.data()); for(size_t i = 0; i < hc.size(); i++) - if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); } diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index 9b81bf6ad..cec282d34 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -92,10 +92,11 @@ public: for(int32_t c = 0; c < NC_; ++c){ int32_t h = p; int32_t w = q; - if(h >= BH_/2 && h < AH_ - BH_/2) + if(h >= BH_/2 && h < AH_ - BH_/2 + && w >= BW_/2 && w < AW_ - BW_/2){ h += shift_h_[c]; - if(w > BW_/2 && w < AW_ - BW_/2) w += shift_w_[c]; + } IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]; IN_DTYPE b = F[k + c*NF_]; acc = std::fma(a, b, acc); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index dbaa3f496..87b158648 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -53,6 +53,9 @@ void shift::build_deltas() { h_deltas_[c] += shift_h_[c]*ld_a_[1]; h_deltas_[c] += shift_w_[c]*ld_a_[2]; } + for(unsigned c = 0; c < NC_; c++){ + h_deltas_[c + 256] = c*ld_a_[0]; + } } void shift::build_masks() {