Fix bug in shift-conv striding: pass a contiguous dy to the shift_conv_dx / shift_conv_dw backward kernels
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
from torch.nn.modules.utils import _single, _pair, _triple
|
||||
from torch.distributions import categorical
|
||||
|
||||
@@ -117,11 +118,13 @@ class ShiftConvFunction(torch.autograd.Function):
|
||||
shift_w = ctx.shift_w
|
||||
dx = dw = dbias = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
||||
dx = torch.ops.triton.shift_conv_dx(dy.contiguous(), weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
||||
if ctx.needs_input_grad[1]:
|
||||
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
||||
dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
||||
if ctx.needs_input_grad[2]:
|
||||
dbias = torch.sum(dy, (1, 2, 3))
|
||||
#print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
|
||||
#print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
|
||||
return dx, dw, dbias, None, None, None, None
|
||||
|
||||
|
||||
@@ -149,7 +152,7 @@ class _ShiftConvNd(torch.nn.Module):
|
||||
|
||||
def make_shift(self, kernel_size):
|
||||
if kernel_size == 3:
|
||||
p = torch.Tensor([0.3, 0.4, 0.3])
|
||||
p = torch.Tensor([0., 1., 0.])
|
||||
elif kernel_size == 5:
|
||||
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
|
||||
elif kernel_size == 7:
|
||||
|
Reference in New Issue
Block a user