From ec24e1e7df15572f34454b270797fa7f5e812c9e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 16 Jul 2019 18:47:50 -0700 Subject: [PATCH] trying to remove interior logic --- examples/cpp/shift.cpp | 6 +++--- examples/python/tensorflow/run.py | 4 ++-- lib/codegen/tune.cpp | 2 +- lib/dnn/base.cpp | 2 +- lib/dnn/shift.cpp | 36 +++++++++++-------------------- 5 files changed, 20 insertions(+), 30 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 020dba23a..754853a8e 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -14,13 +14,13 @@ int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); - auto op = triton::dnn::shift::FPROP; + auto op = triton::dnn::shift::BPROP; // initialization int32_t R = 3, S = 3; - int32_t B = 128, F = 128; + int32_t B = 16, F = 4096; int32_t H = 16, W = 16; - int32_t C = 128; + int32_t C = 4096; // random shifts std::vector shift_h(C); diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 809967f84..5fb1d9314 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -128,6 +128,6 @@ def run_batchnorm(): print(np.max(np.abs(dg_t - dg_n))) print(np.max(np.abs(db_t - db_n))) -run_dot() -#run_shift() +#run_dot() +run_shift() #run_batchnorm() diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index e78440f52..f18afeeac 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -237,7 +237,7 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 4)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 4)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index 9798d8cb3..a3a3ce403 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -51,7 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector args, b jit->add_module(name_.c_str(), src.c_str(), best.params); } else { - jit->add_module(name_.c_str(), src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1}); + jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); } triton::driver::kernel* kernel = jit->get_function(name_.c_str()); clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index b9e580506..0691a5980 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -80,7 +80,9 @@ shift::shift(int B, int C, throw std::runtime_error("unsupported input layout"); } // Equivalent matmul - M_ = B_*CH_*CW_; + M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2); + if(M_ == 0) + throw std::runtime_error("unsupported input shapes - no interior !"); N_ = F_; K_ = C_; // transpose @@ -288,14 +290,14 @@ void shift::triton_c_src(std::ostream &os) const { return R"( int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB; int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB; - int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW; - int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)"; + int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w; + int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)"; } else { return R"( int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW; - int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW; - int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH; + int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w; + int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h; int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)"; } }; @@ -370,10 +372,7 @@ if(op_ == FPROP){ int32 offa0[TM, TK] = offxa[:, newaxis]; __constant__ int32* pd[TK] = delta_a + rka; multiple_of(4) int32 d[TK] = *pd; - int32 offa_interior[TM, TK] = d[newaxis, :]; - int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c; - )" + compute_interior("ra", "TM", "TK") + R"( - int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; + int32 offa1[TM, TK] = d[newaxis, :];)"; } if(op_ == BPROP){ result += @@ -415,10 +414,8 @@ if(op_ == WGRAD){ rbw = rbw * stride_w; rbh = rbh * stride_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - )" + compute_interior("rb", "TK", "TN") + R"( - int32 incb[TK, TN] = interior ? shift : 0; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)"; + int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)"; } /* Main loop */ @@ -439,10 +436,7 @@ if(op_ == FPROP){ result += R"( pd = pd + TK; d = *pd; - offa_interior = d[newaxis, :]; - offa_exterior = TK * lda_c; - int32 offa[TM, TK] = interior ? offa_interior : offa_exterior; - pa = pa + offa;)"; + pa = pa + d[newaxis, :];)"; } if(op_ == BPROP){ result += R"( @@ -470,9 +464,7 @@ if(op_ == WGRAD){ rbw = rbw * stride_w; rbh = rbh * stride_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - )" + compute_interior("rb", "TK", "TN") + R"( - incb = interior ? shift : 0; - pb = B + offb0 + offkb[:, newaxis] + incb;)"; + pb = B + offb0 + offkb[:, newaxis] + shift;)"; } if(op_ == FPROP){ result += R"( @@ -513,11 +505,9 @@ if(op_ == WGRAD){ int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; if(op_ == BPROP){ result += R"( - )" + compute_interior("rc", "TM", "TN") + R"( __constant__ int32* pd[TN] = delta_a + ryc; - )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; - pc = interior ? shift_pc : pc; - @checkc __atomic_add(pc, c); + pc = pc + (*pd)[newaxis, :]; + @checkc *pc = c; )"; } else{