[examples/pytorch] Fixed issues in backward pass of conv
This commit is contained in:
@@ -24,10 +24,13 @@ conv::conv(int B, int NC,
|
||||
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
|
||||
// swap a and c for bprop
|
||||
if(ty_ == BPROP){
|
||||
pad_d_ = (CD_ - AD_ + BD_ - 1) / 2;
|
||||
pad_h_ = (CH_ - AH_ + BH_ - 1) / 2;
|
||||
pad_w_ = (CW_ - AW_ + BW_ - 1) / 2;
|
||||
std::swap(AD_, CD_);
|
||||
std::swap(AH_, CH_);
|
||||
std::swap(AW_, CW_);
|
||||
shapes_a_.swap(shapes_c_);
|
||||
pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2;
|
||||
pad_h_ = (CH_*stride_h_ - AH_*upsample_h_ + BH_ - 1 - stride_h_ + 1)/2;
|
||||
pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2;
|
||||
}
|
||||
// swap b and c for wgrad
|
||||
if(ty_ == WGRAD){
|
||||
|
Reference in New Issue
Block a user