testing a simple shiftnet

This commit is contained in:
Philippe Tillet
2019-07-10 13:33:08 -07:00
parent 3b89bc8463
commit f665c742f9
6 changed files with 261 additions and 31 deletions

View File

@@ -1,6 +1,7 @@
import torch
from torch.nn.modules.utils import _single, _pair, _triple
import math
from torch.nn.modules.utils import _single, _pair, _triple
from torch.distributions import categorical
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
@@ -93,13 +94,18 @@ class Conv2d(_ConvNd):
class ShiftConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, stride, width):
def forward(ctx, input, weight, bias, stride, width, shift_h, shift_w):
if bias is None:
bias = torch.empty(0)
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.width = width
output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
ctx.shift_h = shift_h
ctx.shift_w = shift_w
output = torch.ops.triton.shift_conv_y(input, weight, bias,
width[0], width[1],
stride[0], stride[1],
shift_h, shift_w)
return output
@staticmethod
@@ -107,16 +113,70 @@ class ShiftConvFunction(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
width = ctx.width
shift_h = ctx.shift_h
shift_w = ctx.shift_w
dx = dw = dbias = None
if ctx.needs_input_grad[0]:
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[1]:
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[2]:
dbias = torch.sum(dy, (1, 2, 3))
return dx, dw, dbias, None, None
return dx, dw, dbias, None, None, None, None
class _ShiftConvNd(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias):
super(_ShiftConvNd, self).__init__()
# initialize
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.shift_h = self.make_shift(kernel_size[0])
self.shift_w = self.make_shift(kernel_size[1])
self.reset_parameters()
def forward(self, input):
return ShiftConvFunction.apply(input, self.weight, self.bias, self.stride,
self.kernel_size, self.shift_h, self.shift_w)
def make_shift(self, kernel_size):
if kernel_size == 3:
p = torch.Tensor([0.3, 0.4, 0.3])
elif kernel_size == 5:
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
elif kernel_size == 7:
p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075])
elif kernel_size == 9:
p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05])
else:
raise RuntimeError('Unsupported kernel size')
return categorical.Categorical(p).sample((self.in_channels,)) - (kernel_size // 2)
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
class ShiftConv2d(_ShiftConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
super(ShiftConv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, bias)
#################################
######### BatchNorm ###########
#################################
@@ -134,5 +194,27 @@ class BatchNormFunction(torch.autograd.Function):
def backward(ctx, dy):
eps = ctx.eps
x, gamma, beta, mean, var = ctx.saved_tensors
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
return dx, dg, db, None, None
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy.contiguous(), x, gamma, mean, var, eps)
return dx, dg, db, None
class _BatchNorm(torch.nn.Module):
def __init__(self, num_features, eps=1e-5):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.weight = torch.nn.Parameter(torch.Tensor(num_features))
self.bias = torch.nn.Parameter(torch.Tensor(num_features))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.uniform_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, input):
return BatchNormFunction.apply(input, self.weight, self.bias, self.eps)
class BatchNorm2d(_BatchNorm):
pass