[examples/pytorch] Fixed issues in backward pass of conv

This commit is contained in:
Philippe Tillet
2019-05-19 01:31:08 -04:00
parent b2b55c52c9
commit f33a1f3fe3
9 changed files with 541 additions and 71 deletions

View File

@@ -86,7 +86,8 @@ torch::Tensor conv_common(
torch::Tensor conv_fprop(
const torch::Tensor data,
const torch::Tensor weight) {
const torch::Tensor weight,
int64_t pad_h, int64_t pad_w) {
// Check
CHECK_INPUT(data);
CHECK_INPUT(weight);
@@ -104,7 +105,7 @@ torch::Tensor conv_fprop(
const int32_t NF = weight.size(3);
// Configuration
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t pad_d = 0;
// Check
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
@@ -112,7 +113,8 @@ torch::Tensor conv_fprop(
torch::Tensor conv_bprop(
const torch::Tensor derror,
const torch::Tensor weight){
const torch::Tensor weight,
int64_t pad_h, int64_t pad_w){
// Check
CHECK_INPUT(derror);
CHECK_INPUT(weight);
@@ -131,10 +133,12 @@ torch::Tensor conv_bprop(
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
int32_t pad_d = 0;
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
// Check
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
@@ -142,17 +146,18 @@ torch::Tensor conv_bprop(
torch::Tensor conv_wgrad(
const torch::Tensor data,
const torch::Tensor derror
const torch::Tensor derror,
int64_t pad_h, int64_t pad_w
){
// Check
CHECK_INPUT(data);
CHECK_INPUT(derror);
// Unpack data shapes
const int32_t Ba = derror.size(0);
const int32_t C = derror.size(1);
const int32_t Ba = data.size(0);
const int32_t C = data.size(1);
const int32_t D = 1;
const int32_t H = derror.size(2);
const int32_t W = derror.size(3);
const int32_t H = data.size(2);
const int32_t W = data.size(3);
// Unpack error shapes
const int32_t Bb = derror.size(0);
const int32_t K = derror.size(1);
@@ -162,10 +167,12 @@ torch::Tensor conv_wgrad(
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
const int32_t pad_d = 0;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
// Check
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);