ugh bug in shift-conv striding

This commit is contained in:
Philippe Tillet
2019-07-10 17:00:22 -07:00
parent f665c742f9
commit 4ca83f1935
5 changed files with 143 additions and 26 deletions

View File

@@ -1,5 +1,6 @@
import torch
import math
import numpy as np
from torch.nn.modules.utils import _single, _pair, _triple
from torch.distributions import categorical
@@ -117,11 +118,13 @@ class ShiftConvFunction(torch.autograd.Function):
shift_w = ctx.shift_w
dx = dw = dbias = None
if ctx.needs_input_grad[0]:
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
dx = torch.ops.triton.shift_conv_dx(dy.contiguous(), weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[1]:
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[2]:
dbias = torch.sum(dy, (1, 2, 3))
#print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
#print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
return dx, dw, dbias, None, None, None, None
@@ -149,7 +152,7 @@ class _ShiftConvNd(torch.nn.Module):
def make_shift(self, kernel_size):
if kernel_size == 3:
p = torch.Tensor([0.3, 0.4, 0.3])
p = torch.Tensor([0., 1., 0.])
elif kernel_size == 5:
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
elif kernel_size == 7: