From 2672812ad01c32013c201daf1289982e3a8e2291 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Wed, 22 May 2019 15:25:43 -0400
Subject: [PATCH] [dnn/conv] No more divergent path in conv::set_arg

Add per-tensor axis-index members to dnn::conv. The constructor sets
them to the default layouts (A: NCHW, B: CRSK, C: NKPQ) and permutes
them in the same branch that swaps the B*/C* extents. conv::set_arg
then looks up ld_a_/ld_b_/ld_c_ through these indices, which replaces
the three `if(ty_ == WGRAD) ... else ...` blocks with straight-line
code. The permuted indices select the same leading dimensions, in the
same kernel-argument order, as the removed WGRAD branches.

Also label the pad / stride / dilate argument groups in set_arg.
---
Review notes (ignored by git-am):
 * The ninth index member was spelled `c_pix_idx`; it is spelled
   `c_pix_idx_` here so that it follows the trailing-underscore
   convention of the other eight members. Only added lines changed, so
   hunk line counts and the diffstat are unaffected, but the post-image
   blob ids on the `index` lines were not regenerated.
 * Line structure and leading indentation were reconstructed from a
   whitespace-collapsed copy of this patch (2-space indentation
   assumed); apply with `--ignore-whitespace` if context does not match.
 * a_pix_idx_ is assigned but set_arg still passes ld_a_[2..4] directly.

 include/triton/dnn/conv.h | 10 ++++++
 lib/dnn/conv.cpp          | 68 ++++++++++++++++++---------------------
 2 files changed, 41 insertions(+), 37 deletions(-)

diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index 313065fc6..af8861bb8 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -266,6 +266,16 @@ private:
   bool bias_;
   bool b_trans_;
   bool b_lut_;
+  // axis index
+  int32_t a_inner_idx_;
+  int32_t a_outer_idx_;
+  int32_t a_pix_idx_;
+  int32_t b_inner_idx_;
+  int32_t b_outer_idx_;
+  int32_t b_pix_idx_;
+  int32_t c_outer_0_idx_;
+  int32_t c_outer_1_idx_;
+  int32_t c_pix_idx_;
 };
 
 }
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index fc97acbba..2cdfecb63 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -19,11 +19,22 @@ conv::conv(int B, int NC,
   CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
   CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
   CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
-
   // shapes
   shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
   shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
   shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
+  // a layout - NCHW
+  a_outer_idx_ = 0;
+  a_inner_idx_ = 1;
+  a_pix_idx_ = 2;
+  // b layout - CRSK
+  b_inner_idx_ = 0;
+  b_pix_idx_ = 1;
+  b_outer_idx_ = 4;
+  // c layout - NKPQ
+  c_outer_0_idx_ = 0;
+  c_outer_1_idx_ = 1;
+  c_pix_idx_ = 2;
   // swap a and c for bprop
   if(ty_ == BPROP){
     std::swap(AD_, CD_);
@@ -40,6 +51,10 @@ conv::conv(int B, int NC,
     std::swap(BD_, CD_);
     std::swap(BH_, CH_);
     std::swap(BW_, CW_);
+    std::swap(a_outer_idx_, a_inner_idx_);
+    std::swap(b_inner_idx_, c_outer_0_idx_);
+    std::swap(b_outer_idx_, c_outer_1_idx_);
+    std::swap(b_pix_idx_, c_pix_idx_);
   }
   // leading dimensions
   auto set_ld = [](const std::vector& shapes,
@@ -250,51 +265,30 @@ void conv::set_arg(driver::kernel *kernel,
   kernel->setArg(11, CH_);
   kernel->setArg(12, CW_);
   // A arguments
-  if(ty_ == WGRAD){
-    kernel->setArg(13, ld_a_[1]);
-    kernel->setArg(14, ld_a_[0]);
-  }
-  else{
-    kernel->setArg(13, ld_a_[0]);
-    kernel->setArg(14, ld_a_[1]);
-  }
+  kernel->setArg(13, ld_a_[a_outer_idx_]);
+  kernel->setArg(14, ld_a_[a_inner_idx_]);
   kernel->setArg(15, ld_a_[2]);
   kernel->setArg(16, ld_a_[3]);
   kernel->setArg(17, ld_a_[4]);
   // B arguments
-  if(ty_ == WGRAD){
-    kernel->setArg(18, ld_b_[0]);
-    kernel->setArg(19, ld_b_[2]);
-    kernel->setArg(20, ld_b_[3]);
-    kernel->setArg(21, ld_b_[4]);
-    kernel->setArg(22, ld_b_[1]);
-  }
-  else{
-    kernel->setArg(18, ld_b_[0]);
-    kernel->setArg(19, ld_b_[1]);
-    kernel->setArg(20, ld_b_[2]);
-    kernel->setArg(21, ld_b_[3]);
-    kernel->setArg(22, ld_b_[4]);
-  }
+  kernel->setArg(18, ld_b_[b_inner_idx_]);
+  kernel->setArg(19, ld_b_[b_pix_idx_]);
+  kernel->setArg(20, ld_b_[b_pix_idx_+1]);
+  kernel->setArg(21, ld_b_[b_pix_idx_+2]);
+  kernel->setArg(22, ld_b_[b_outer_idx_]);
   // C arguments
-  if(ty_ == WGRAD){
-    kernel->setArg(23, ld_c_[0]);
-    kernel->setArg(24, ld_c_[4]);
-    kernel->setArg(25, ld_c_[1]);
-    kernel->setArg(26, ld_c_[2]);
-    kernel->setArg(27, ld_c_[3]);
-  }
-  else{
-    kernel->setArg(23, ld_c_[0]);
-    kernel->setArg(24, ld_c_[1]);
-    kernel->setArg(25, ld_c_[2]);
-    kernel->setArg(26, ld_c_[3]);
-    kernel->setArg(27, ld_c_[4]);
-  }
+  kernel->setArg(23, ld_c_[c_outer_0_idx_]);
+  kernel->setArg(24, ld_c_[c_outer_1_idx_]);
+  kernel->setArg(25, ld_c_[c_pix_idx_]);
+  kernel->setArg(26, ld_c_[c_pix_idx_+1]);
+  kernel->setArg(27, ld_c_[c_pix_idx_+2]);
+  // pad
   kernel->setArg(28, pad_h_);
   kernel->setArg(29, pad_w_);
+  // stride
   kernel->setArg(30, stride_h_);
   kernel->setArg(31, stride_w_);
+  // dilate
   kernel->setArg(32, upsample_h_);
   kernel->setArg(33, upsample_w_);
   size_t idx = 34;