[dnn/conv] No more divergent path in conv::set_arg
This commit is contained in:
@@ -266,6 +266,16 @@ private:
|
||||
bool bias_;
|
||||
bool b_trans_;
|
||||
bool b_lut_;
|
||||
// axis index
|
||||
int32_t a_inner_idx_;
|
||||
int32_t a_outer_idx_;
|
||||
int32_t a_pix_idx_;
|
||||
int32_t b_inner_idx_;
|
||||
int32_t b_outer_idx_;
|
||||
int32_t b_pix_idx_;
|
||||
int32_t c_outer_0_idx_;
|
||||
int32_t c_outer_1_idx_;
|
||||
int32_t c_pix_idx;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -19,11 +19,22 @@ conv::conv(int B, int NC,
|
||||
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||
|
||||
// shapes
|
||||
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
|
||||
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
|
||||
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
|
||||
// a layout - NCHW
|
||||
a_outer_idx_ = 0;
|
||||
a_inner_idx_ = 1;
|
||||
a_pix_idx_ = 2;
|
||||
// b layout - CRSK
|
||||
b_inner_idx_ = 0;
|
||||
b_pix_idx_ = 1;
|
||||
b_outer_idx_ = 4;
|
||||
// c layout - NKPQ
|
||||
c_outer_0_idx_ = 0;
|
||||
c_outer_1_idx_ = 1;
|
||||
c_pix_idx = 2;
|
||||
// swap a and c for bprop
|
||||
if(ty_ == BPROP){
|
||||
std::swap(AD_, CD_);
|
||||
@@ -40,6 +51,10 @@ conv::conv(int B, int NC,
|
||||
std::swap(BD_, CD_);
|
||||
std::swap(BH_, CH_);
|
||||
std::swap(BW_, CW_);
|
||||
std::swap(a_outer_idx_, a_inner_idx_);
|
||||
std::swap(b_inner_idx_, c_outer_0_idx_);
|
||||
std::swap(b_outer_idx_, c_outer_1_idx_);
|
||||
std::swap(b_pix_idx_, c_pix_idx);
|
||||
}
|
||||
// leading dimensions
|
||||
auto set_ld = [](const std::vector<int32_t>& shapes,
|
||||
@@ -250,51 +265,30 @@ void conv::set_arg(driver::kernel *kernel,
|
||||
kernel->setArg(11, CH_);
|
||||
kernel->setArg(12, CW_);
|
||||
// A arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(13, ld_a_[1]);
|
||||
kernel->setArg(14, ld_a_[0]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(13, ld_a_[0]);
|
||||
kernel->setArg(14, ld_a_[1]);
|
||||
}
|
||||
kernel->setArg(13, ld_a_[a_outer_idx_]);
|
||||
kernel->setArg(14, ld_a_[a_inner_idx_]);
|
||||
kernel->setArg(15, ld_a_[2]);
|
||||
kernel->setArg(16, ld_a_[3]);
|
||||
kernel->setArg(17, ld_a_[4]);
|
||||
// B arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(18, ld_b_[0]);
|
||||
kernel->setArg(19, ld_b_[2]);
|
||||
kernel->setArg(20, ld_b_[3]);
|
||||
kernel->setArg(21, ld_b_[4]);
|
||||
kernel->setArg(22, ld_b_[1]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(18, ld_b_[0]);
|
||||
kernel->setArg(19, ld_b_[1]);
|
||||
kernel->setArg(20, ld_b_[2]);
|
||||
kernel->setArg(21, ld_b_[3]);
|
||||
kernel->setArg(22, ld_b_[4]);
|
||||
}
|
||||
kernel->setArg(18, ld_b_[b_inner_idx_]);
|
||||
kernel->setArg(19, ld_b_[b_pix_idx_]);
|
||||
kernel->setArg(20, ld_b_[b_pix_idx_+1]);
|
||||
kernel->setArg(21, ld_b_[b_pix_idx_+2]);
|
||||
kernel->setArg(22, ld_b_[b_outer_idx_]);
|
||||
// C arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(23, ld_c_[0]);
|
||||
kernel->setArg(24, ld_c_[4]);
|
||||
kernel->setArg(25, ld_c_[1]);
|
||||
kernel->setArg(26, ld_c_[2]);
|
||||
kernel->setArg(27, ld_c_[3]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(23, ld_c_[0]);
|
||||
kernel->setArg(24, ld_c_[1]);
|
||||
kernel->setArg(25, ld_c_[2]);
|
||||
kernel->setArg(26, ld_c_[3]);
|
||||
kernel->setArg(27, ld_c_[4]);
|
||||
}
|
||||
kernel->setArg(23, ld_c_[c_outer_0_idx_]);
|
||||
kernel->setArg(24, ld_c_[c_outer_1_idx_]);
|
||||
kernel->setArg(25, ld_c_[c_pix_idx]);
|
||||
kernel->setArg(26, ld_c_[c_pix_idx+1]);
|
||||
kernel->setArg(27, ld_c_[c_pix_idx+2]);
|
||||
// pad
|
||||
kernel->setArg(28, pad_h_);
|
||||
kernel->setArg(29, pad_w_);
|
||||
// stride
|
||||
kernel->setArg(30, stride_h_);
|
||||
kernel->setArg(31, stride_w_);
|
||||
// dilate
|
||||
kernel->setArg(32, upsample_h_);
|
||||
kernel->setArg(33, upsample_w_);
|
||||
size_t idx = 34;
|
||||
|
Reference in New Issue
Block a user