From 07c964919cf78c86514e39ceaeefe7e4568eb317 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 16 Jul 2019 20:18:48 -0700 Subject: [PATCH] [dnn/shift] now strictly only shifting the interior --- examples/python/pytorch/shift.cpp | 2 +- examples/python/tensorflow/shift.cpp | 2 +- include/triton/dnn/shift.h | 8 +++++ lib/codegen/tune.cpp | 4 +-- lib/dnn/shift.cpp | 50 +++++++++++----------------- lib/runtime/jit.cpp | 3 +- 6 files changed, 34 insertions(+), 35 deletions(-) diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp index 7efe0198b..d25ed588f 100644 --- a/examples/python/pytorch/shift.cpp +++ b/examples/python/pytorch/shift.cpp @@ -75,7 +75,7 @@ torch::Tensor shift_common( triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false); // Enqueue - shift.enqueue(&stream, {&a, &b, &c}, true); + shift.enqueue(&stream, {&a, &b, &c}, false); return torchc; } diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp index d844e9aa1..1834cadaf 100644 --- a/examples/python/tensorflow/shift.cpp +++ b/examples/python/tensorflow/shift.cpp @@ -122,7 +122,7 @@ public: triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat().data(), false); - shift.enqueue(stream, {&da, &db, &dc}); + shift.enqueue(stream, {&da, &db, &dc}, false); } private: diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index fbff404ca..84c6ccda7 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -128,6 +128,14 @@ private: int32_t CD_; int32_t CH_; int32_t CW_; + // interior image size + int32_t IAD_; + int32_t IAH_; + int32_t IAW_; + // interior activation size + int32_t ICD_; + int32_t ICH_; + int32_t ICW_; // equivalent matmul int32_t M_; int32_t N_; diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index f18afeeac..6c9522f03 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -223,7 +223,7 @@ void tune::run(ir::module &mod) { } else { ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); - ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4); + ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); } } @@ -237,7 +237,7 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 4)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 2)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 0691a5980..49212619d 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -79,10 +79,15 @@ shift::shift(int B, int C, default: throw std::runtime_error("unsupported input layout"); } + IAD_ = AD_ - 2*(BD_/2); + IAH_ = AH_ - 2*(BH_/2); + IAW_ = AW_ - 2*(BW_/2); + ICD_ = IAD_ / stride_d_; + ICH_ = IAH_ / stride_h_; + ICW_ = IAW_ / stride_w_; + // Equivalent matmul - M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2); - if(M_ == 0) - throw std::runtime_error("unsupported input shapes - no interior !"); + M_ = B_*ICH_*ICW_; N_ = F_; K_ = C_; // transpose @@ -247,21 +252,21 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(18, ldc_h_); kernel->setArg(19, ldc_f_); kernel->setArg(20, B_); - kernel->setArg(21, AH_); - kernel->setArg(22, AW_); + kernel->setArg(21, IAH_); + kernel->setArg(22, IAW_); kernel->setArg(23, BH_); kernel->setArg(24, BW_); - kernel->setArg(25, CH_); - kernel->setArg(26, CW_); + kernel->setArg(25, ICH_); + kernel->setArg(26, ICW_); kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_); kernel->setArg(28, (int32_t)grid[0]); kernel->setArg(29, (int32_t)grid[1]); kernel->setArg(30, (int32_t)grid[2]); if(locks_) ((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4); - if(op_ == BPROP){ + if(op_ == FPROP || op_ == BPROP){ size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4; - ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes); + ((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes); } stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); } @@ -290,33 +295,18 @@ void shift::triton_c_src(std::ostream &os) const { return R"( int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB; int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB; - int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w; - int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)"; + int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w; + int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)"; } else { return R"( int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW; - int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w; - int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h; + int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w; + int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h; int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)"; } }; - auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) { - std::string result; - if(shift_edge_h_) - result += "int1 interiorh[" + sz0 + "] = 1;\n "; - else - result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n "; - if(shift_edge_w_) - result += "int1 interiorw[" + sz0 + "] = 1;"; - else - result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));"; - result += R"( - int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];"; - return result; - }; - std::string result = R"( const tunable int32 TM = {16, 32, 64, 128}; @@ -506,8 +496,8 @@ if(op_ == WGRAD){ if(op_ == BPROP){ result += R"( __constant__ int32* pd[TN] = delta_a + ryc; - pc = pc + (*pd)[newaxis, :]; - @checkc *pc = c; + )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; + @checkc *shift_pc = c; )"; } else{ diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 14bec7172..e42c534b6 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -43,7 +43,8 @@ void loop_nest(std::vector const & ranges, // size_t current = 0; while(true){ //Execute function - pool.enqueue([values, &f](){ f(values); }); +// pool.enqueue([values, &f](){ f(values); }); + f(values); while(values[i]++ == ranges[i] - 1){ if(i == 0) return;