[examples/pytorch] Fixed issues in backward pass of conv
This commit is contained in:
@@ -86,7 +86,8 @@ torch::Tensor conv_common(
|
||||
|
||||
torch::Tensor conv_fprop(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor weight) {
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w) {
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -104,7 +105,7 @@ torch::Tensor conv_fprop(
|
||||
const int32_t NF = weight.size(3);
|
||||
// Configuration
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t pad_d = 0;
|
||||
// Check
|
||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
|
||||
@@ -112,7 +113,8 @@ torch::Tensor conv_fprop(
|
||||
|
||||
torch::Tensor conv_bprop(
|
||||
const torch::Tensor derror,
|
||||
const torch::Tensor weight){
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w){
|
||||
// Check
|
||||
CHECK_INPUT(derror);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -131,10 +133,12 @@ torch::Tensor conv_bprop(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
|
||||
const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
|
||||
const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
|
||||
int32_t pad_d = 0;
|
||||
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
|
||||
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
|
||||
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
|
||||
@@ -142,17 +146,18 @@ torch::Tensor conv_bprop(
|
||||
|
||||
torch::Tensor conv_wgrad(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor derror
|
||||
const torch::Tensor derror,
|
||||
int64_t pad_h, int64_t pad_w
|
||||
){
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(derror);
|
||||
// Unpack data shapes
|
||||
const int32_t Ba = derror.size(0);
|
||||
const int32_t C = derror.size(1);
|
||||
const int32_t Ba = data.size(0);
|
||||
const int32_t C = data.size(1);
|
||||
const int32_t D = 1;
|
||||
const int32_t H = derror.size(2);
|
||||
const int32_t W = derror.size(3);
|
||||
const int32_t H = data.size(2);
|
||||
const int32_t W = data.size(3);
|
||||
// Unpack error shapes
|
||||
const int32_t Bb = derror.size(0);
|
||||
const int32_t K = derror.size(1);
|
||||
@@ -162,10 +167,12 @@ torch::Tensor conv_wgrad(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
|
||||
const int32_t pad_d = 0;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
|
||||
|
Reference in New Issue
Block a user