ugh bug in shift-conv striding
@@ -6,10 +6,97 @@ import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import triton
from torch.utils.cpp_extension import load
from torch.distributions import categorical

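# JIT-compile the custom CUDA shift kernels; `load` builds the extension on
# first use and returns a module exposing the forward/backward bindings below.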
shift_cuda = load(
|
||||
'shift_cuda', ['/home/philippe/development/shiftnet/kernels/shift_cuda.cpp',
|
||||
'/home/philippe/development/shiftnet/kernels/shift_cuda_kernel.cu'], extra_cflags=['-O3'])
|
||||
|
||||
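# Thin autograd wrapper around the compiled kernels: the per-channel shift
# offsets are saved in forward and reused to shift the gradient in backward.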
class shift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        ctx.save_for_backward(shift)
        return shift_cuda.forward(x, shift)

    @staticmethod
    def backward(ctx, grad_output):
        shift, = ctx.saved_tensors
        grad_output = shift_cuda.backward(grad_output, shift)
        return grad_output, None

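# Shift module: draws a fixed random (dy, dx) offset per channel at
# construction; `p` weights offsets near zero more heavily for larger kernels.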
class Shift(nn.Module):
    def __init__(self, in_channels, kernel_size):
        super(Shift, self).__init__()
        self.channels = in_channels
        self.kernel_size = kernel_size
        if kernel_size == 3:
            p = torch.Tensor([0., 1., 0.])
        elif kernel_size == 5:
            p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
        elif kernel_size == 7:
            p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075])
        elif kernel_size == 9:
            p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05])
        else:
            raise RuntimeError('Unsupported kernel size')
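        # Sample one offset per channel and axis from {0, ..., kernel_size-1},
        # then re-center into [-(kernel_size // 2), kernel_size // 2].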
        shift_t = categorical.Categorical(p).sample((in_channels, 2)) - (kernel_size // 2)
        self.register_buffer('shift_t', shift_t.int())

    def forward(self, x):
        if x.is_cuda:
            return shift.apply(x, self.shift_t)
        else:
            raise RuntimeError('Shift only supports GPU for now')

    def extra_repr(self):
        return '{channels}, kernel_size={kernel_size}'.format(**self.__dict__)

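# Shift-conv: a parameter-free channel shift followed by a strided 1x1
# convolution, standing in for a dense kernel_size x kernel_size conv.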
def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
    return nn.Sequential(
        Shift(in_planes, kernel_size),
        nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                  padding=0, groups=groups, bias=False)
    )

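# Example (hypothetical shapes): on MNIST, ShiftConv2d(1, 32, 3, stride=2)
# maps an NCHW batch of (N, 1, 28, 28) to (N, 32, 14, 14); note `dilation`
# is accepted but currently unused.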

class NetReference(nn.Module):
    def __init__(self):
        super(NetReference, self).__init__()
        #self.conv1 = ShiftConv2d(1, 32, 3, 2)
        self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
        self.bn1 = nn.BatchNorm2d(32)
        #self.conv2a = ShiftConv2d(32, 32, 3, 1)
        self.conv2b = triton.ShiftConv2d(32, 32, 3, 2)
        #self.conv2b = ShiftConv2d(32, 32, 3, 2)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32*7*7, 500)
        self.fc2 = nn.Linear(500, 10)

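    # The triton shift-conv operates on a CHWN layout, so activations are
    # permuted out of NCHW around each conv and back for BatchNorm/ReLU.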
    def forward(self, x):
        x = x.permute(1, 2, 3, 0).contiguous()
        x = self.conv1(x)
        x = x.permute(3, 0, 1, 2).contiguous()
        x = self.bn1(x)
        x = F.relu(x)
        x = x.permute(1, 2, 3, 0).contiguous()
        x = self.conv2b(x)
        x = x.permute(3, 0, 1, 2).contiguous()
        x = self.bn2(x)
        x = F.relu(x)
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class NetTriton(nn.Module):
    def __init__(self):
        super(NetTriton, self).__init__()
        self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
        self.bn1 = triton.BatchNorm2d(32)
        self.conv2 = triton.ShiftConv2d(32, 64, 3, 2)
@@ -23,6 +110,7 @@ class Net(nn.Module):
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = x.permute(3, 0, 1, 2).contiguous()
        x = x.view(-1, 64*7*7)
@@ -30,6 +118,8 @@ class Net(nn.Module):
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

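# Use the reference implementation for this run; NetTriton is the
# all-Triton variant.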
Net = NetReference()


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
@@ -107,7 +197,7 @@ def main():
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net.to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):