[dnn/conv] No more divergent path in conv::set_arg

This commit is contained in:
Philippe Tillet
2019-05-22 15:25:43 -04:00
parent e8f23bcade
commit 2672812ad0
2 changed files with 41 additions and 37 deletions

View File

@@ -266,6 +266,16 @@ private:
bool bias_;
bool b_trans_;
bool b_lut_;
// axis indices: positions within ld_a_/ld_b_/ld_c_ used to order the stride kernel arguments
int32_t a_inner_idx_;
int32_t a_outer_idx_;
int32_t a_pix_idx_;
int32_t b_inner_idx_;
int32_t b_outer_idx_;
int32_t b_pix_idx_;
int32_t c_outer_0_idx_;
int32_t c_outer_1_idx_;
int32_t c_pix_idx;
};
}

View File

@@ -19,11 +19,22 @@ conv::conv(int B, int NC,
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
// shapes
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
// a layout - NCHW with a leading depth axis (5-D, same order as shapes_a_); pix index = first of the 3 spatial axes
a_outer_idx_ = 0;
a_inner_idx_ = 1;
a_pix_idx_ = 2;
// b layout - CRSK with a leading depth axis (5-D, same order as shapes_b_); pix index = first of the 3 spatial axes
b_inner_idx_ = 0;
b_pix_idx_ = 1;
b_outer_idx_ = 4;
// c layout - NKPQ with a leading depth axis (5-D, same order as shapes_c_); pix index = first of the 3 spatial axes
c_outer_0_idx_ = 0;
c_outer_1_idx_ = 1;
c_pix_idx = 2;
// swap a and c for bprop
if(ty_ == BPROP){
std::swap(AD_, CD_);
@@ -40,6 +51,10 @@ conv::conv(int B, int NC,
std::swap(BD_, CD_);
std::swap(BH_, CH_);
std::swap(BW_, CW_);
std::swap(a_outer_idx_, a_inner_idx_);
std::swap(b_inner_idx_, c_outer_0_idx_);
std::swap(b_outer_idx_, c_outer_1_idx_);
std::swap(b_pix_idx_, c_pix_idx);
}
// leading dimensions
auto set_ld = [](const std::vector<int32_t>& shapes,
@@ -250,51 +265,30 @@ void conv::set_arg(driver::kernel *kernel,
kernel->setArg(11, CH_);
kernel->setArg(12, CW_);
// A arguments
if(ty_ == WGRAD){
kernel->setArg(13, ld_a_[1]);
kernel->setArg(14, ld_a_[0]);
}
else{
kernel->setArg(13, ld_a_[0]);
kernel->setArg(14, ld_a_[1]);
}
kernel->setArg(13, ld_a_[a_outer_idx_]);
kernel->setArg(14, ld_a_[a_inner_idx_]);
kernel->setArg(15, ld_a_[2]);
kernel->setArg(16, ld_a_[3]);
kernel->setArg(17, ld_a_[4]);
// B arguments
if(ty_ == WGRAD){
kernel->setArg(18, ld_b_[0]);
kernel->setArg(19, ld_b_[2]);
kernel->setArg(20, ld_b_[3]);
kernel->setArg(21, ld_b_[4]);
kernel->setArg(22, ld_b_[1]);
}
else{
kernel->setArg(18, ld_b_[0]);
kernel->setArg(19, ld_b_[1]);
kernel->setArg(20, ld_b_[2]);
kernel->setArg(21, ld_b_[3]);
kernel->setArg(22, ld_b_[4]);
}
kernel->setArg(18, ld_b_[b_inner_idx_]);
kernel->setArg(19, ld_b_[b_pix_idx_]);
kernel->setArg(20, ld_b_[b_pix_idx_+1]);
kernel->setArg(21, ld_b_[b_pix_idx_+2]);
kernel->setArg(22, ld_b_[b_outer_idx_]);
// C arguments
if(ty_ == WGRAD){
kernel->setArg(23, ld_c_[0]);
kernel->setArg(24, ld_c_[4]);
kernel->setArg(25, ld_c_[1]);
kernel->setArg(26, ld_c_[2]);
kernel->setArg(27, ld_c_[3]);
}
else{
kernel->setArg(23, ld_c_[0]);
kernel->setArg(24, ld_c_[1]);
kernel->setArg(25, ld_c_[2]);
kernel->setArg(26, ld_c_[3]);
kernel->setArg(27, ld_c_[4]);
}
kernel->setArg(23, ld_c_[c_outer_0_idx_]);
kernel->setArg(24, ld_c_[c_outer_1_idx_]);
kernel->setArg(25, ld_c_[c_pix_idx]);
kernel->setArg(26, ld_c_[c_pix_idx+1]);
kernel->setArg(27, ld_c_[c_pix_idx+2]);
// pad
kernel->setArg(28, pad_h_);
kernel->setArg(29, pad_w_);
// stride
kernel->setArg(30, stride_h_);
kernel->setArg(31, stride_w_);
// dilate
kernel->setArg(32, upsample_h_);
kernel->setArg(33, upsample_w_);
size_t idx = 34;