[examples/pytorch] Fixed issues in backward pass of conv

This commit is contained in:
Philippe Tillet
2019-05-19 01:31:08 -04:00
parent b2b55c52c9
commit f33a1f3fe3
9 changed files with 541 additions and 71 deletions

View File

@@ -24,10 +24,13 @@ conv::conv(int B, int NC,
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
// swap a and c for bprop
if(ty_ == BPROP){
pad_d_ = (CD_ - AD_ + BD_ - 1) / 2;
pad_h_ = (CH_ - AH_ + BH_ - 1) / 2;
pad_w_ = (CW_ - AW_ + BW_ - 1) / 2;
std::swap(AD_, CD_);
std::swap(AH_, CH_);
std::swap(AW_, CW_);
shapes_a_.swap(shapes_c_);
pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2;
pad_h_ = (CH_*stride_h_ - AH_*upsample_h_ + BH_ - 1 - stride_h_ + 1)/2;
pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2;
}
// swap b and c for wgrad
if(ty_ == WGRAD){