222 lines
8.1 KiB
Python
222 lines
8.1 KiB
Python
import torch
|
|
import math
|
|
import numpy as np
|
|
from torch.nn.modules.utils import _single, _pair, _triple
|
|
from torch.distributions import categorical
|
|
|
|
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
|
|
|
#################################
|
|
####### Convolutions ##########
|
|
#################################
|
|
|
|
class ConvFunction(torch.autograd.Function):
|
|
|
|
@staticmethod
|
|
def forward(ctx, input, weight, bias, stride, padding):
|
|
if bias is None:
|
|
bias = torch.empty(0)
|
|
ctx.save_for_backward(input, weight, bias)
|
|
ctx.stride = stride
|
|
ctx.padding = padding
|
|
output = torch.ops.triton.conv_fprop(input, weight, bias, stride[0], stride[1], padding[0], padding[1])
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
input, weight, bias = ctx.saved_tensors
|
|
stride = ctx.stride
|
|
padding = ctx.padding
|
|
grad_input = grad_weight = grad_bias = None
|
|
if ctx.needs_input_grad[0]:
|
|
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, bias, input.shape[2], input.shape[3], stride[0], stride[1], padding[0], padding[1])
|
|
if ctx.needs_input_grad[1]:
|
|
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, bias, weight.shape[1], weight.shape[2], stride[0], stride[1], padding[0], padding[1])
|
|
if ctx.needs_input_grad[2]:
|
|
grad_bias = torch.sum(grad_output, (0, 2, 3))
|
|
return grad_input, grad_weight, grad_bias, None, None
|
|
|
|
|
|
class _ConvNd(torch.nn.Module):
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
|
padding, dilation, transposed, output_padding, groups, bias):
|
|
super(_ConvNd, self).__init__()
|
|
# not everything is supported by Triton
|
|
assert all(x==1 or x==2 for x in stride)
|
|
assert all(x==1 for x in dilation)
|
|
assert transposed == False
|
|
assert all(x==0 for x in output_padding)
|
|
assert groups == 1
|
|
# initialize
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
self.kernel_size = kernel_size
|
|
self.stride = stride
|
|
self.padding = padding
|
|
self.weight = torch.nn.Parameter(torch.Tensor(
|
|
in_channels, kernel_size[0], kernel_size[1], out_channels))
|
|
if bias:
|
|
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
|
else:
|
|
self.register_parameter('bias', None)
|
|
self.reset_parameters()
|
|
|
|
def forward(self, input):
|
|
return ConvFunction.apply(input, self.weight, self.bias, self.stride, self.padding)
|
|
|
|
def reset_parameters(self):
|
|
n = self.in_channels
|
|
for k in self.kernel_size:
|
|
n *= k
|
|
stdv = 1. / math.sqrt(n)
|
|
self.weight.data.uniform_(-stdv, stdv)
|
|
if self.bias is not None:
|
|
self.bias.data.uniform_(-stdv, stdv)
|
|
|
|
|
|
|
|
class Conv2d(_ConvNd):
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
|
padding=0, dilation=1, groups=1, bias=True):
|
|
kernel_size = _pair(kernel_size)
|
|
stride = _pair(stride)
|
|
padding = _pair(padding)
|
|
dilation = _pair(dilation)
|
|
super(Conv2d, self).__init__(
|
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
|
False, _pair(0), groups, bias)
|
|
|
|
#################################
|
|
#### Shift-Convolutions #######
|
|
#################################
|
|
|
|
class ShiftConvFunction(torch.autograd.Function):
|
|
|
|
@staticmethod
|
|
def forward(ctx, input, weight, bias, stride, width, shift_h, shift_w):
|
|
if bias is None:
|
|
bias = torch.empty(0)
|
|
ctx.save_for_backward(input, weight, bias)
|
|
ctx.stride = stride
|
|
ctx.width = width
|
|
ctx.shift_h = shift_h
|
|
ctx.shift_w = shift_w
|
|
output = torch.ops.triton.shift_conv_y(input, weight, bias,
|
|
width[0], width[1],
|
|
stride[0], stride[1],
|
|
shift_h, shift_w)
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(ctx, dy):
|
|
input, weight, bias = ctx.saved_tensors
|
|
stride = ctx.stride
|
|
width = ctx.width
|
|
shift_h = ctx.shift_h
|
|
shift_w = ctx.shift_w
|
|
dx = dw = dbias = None
|
|
if ctx.needs_input_grad[0]:
|
|
dx = torch.ops.triton.shift_conv_dx(dy.contiguous(), weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
|
if ctx.needs_input_grad[1]:
|
|
dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
|
if ctx.needs_input_grad[2]:
|
|
dbias = torch.sum(dy, (1, 2, 3))
|
|
return dx, dw, dbias, None, None, None, None
|
|
|
|
|
|
class _ShiftConvNd(torch.nn.Module):
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, bias):
|
|
super(_ShiftConvNd, self).__init__()
|
|
# initialize
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
self.kernel_size = kernel_size
|
|
self.stride = stride
|
|
self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
|
|
if bias:
|
|
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
|
else:
|
|
self.register_parameter('bias', None)
|
|
self.shift_h = self.make_shift(kernel_size[0])
|
|
self.shift_w = self.make_shift(kernel_size[1])
|
|
self.reset_parameters()
|
|
|
|
def forward(self, input):
|
|
return ShiftConvFunction.apply(input, self.weight, self.bias, self.stride,
|
|
self.kernel_size, self.shift_h, self.shift_w)
|
|
|
|
def make_shift(self, kernel_size):
|
|
if kernel_size == 3:
|
|
p = torch.Tensor([0.3, 0.4, 0.3])
|
|
elif kernel_size == 5:
|
|
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
|
|
elif kernel_size == 7:
|
|
p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075])
|
|
elif kernel_size == 9:
|
|
p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05])
|
|
else:
|
|
raise RuntimeError('Unsupported kernel size')
|
|
return categorical.Categorical(p).sample((self.in_channels,)) - (kernel_size // 2)
|
|
|
|
def reset_parameters(self):
|
|
n = self.in_channels
|
|
for k in self.kernel_size:
|
|
n *= k
|
|
stdv = 1. / math.sqrt(n)
|
|
self.weight.data.uniform_(-stdv, stdv)
|
|
if self.bias is not None:
|
|
self.bias.data.uniform_(-stdv, stdv)
|
|
|
|
class ShiftConv2d(_ShiftConvNd):
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
|
|
kernel_size = _pair(kernel_size)
|
|
stride = _pair(stride)
|
|
super(ShiftConv2d, self).__init__(
|
|
in_channels, out_channels, kernel_size, stride, bias)
|
|
|
|
#################################
|
|
######### BatchNorm ###########
|
|
#################################
|
|
|
|
class BatchNormFunction(torch.autograd.Function):
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, gamma, beta, eps):
|
|
ctx.eps = eps
|
|
y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, eps)
|
|
ctx.save_for_backward(x, gamma, beta, mean, var)
|
|
return y
|
|
|
|
@staticmethod
|
|
def backward(ctx, dy):
|
|
eps = ctx.eps
|
|
x, gamma, beta, mean, var = ctx.saved_tensors
|
|
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy.contiguous(), x, gamma, mean, var, eps)
|
|
return dx, dg, db, None
|
|
|
|
|
|
class _BatchNorm(torch.nn.Module):
|
|
|
|
def __init__(self, num_features, eps=1e-5):
|
|
super(_BatchNorm, self).__init__()
|
|
self.num_features = num_features
|
|
self.eps = eps
|
|
self.weight = torch.nn.Parameter(torch.Tensor(num_features))
|
|
self.bias = torch.nn.Parameter(torch.Tensor(num_features))
|
|
self.reset_parameters()
|
|
|
|
def reset_parameters(self):
|
|
torch.nn.init.uniform_(self.weight)
|
|
torch.nn.init.zeros_(self.bias)
|
|
|
|
def forward(self, input):
|
|
return BatchNormFunction.apply(input, self.weight, self.bias, self.eps)
|
|
|
|
class BatchNorm2d(_BatchNorm):
|
|
|
|
pass
|