[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm

This commit is contained in:
Philippe Tillet
2019-07-09 21:54:37 -07:00
parent 63b249c1d6
commit 3b89bc8463
8 changed files with 63 additions and 151 deletions

View File

@@ -4,6 +4,10 @@ import math
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
#################################
####### Convolutions ##########
#################################
class ConvFunction(torch.autograd.Function):
@staticmethod
@@ -81,3 +85,54 @@ class Conv2d(_ConvNd):
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias)
#################################
#### Shift-Convolutions #######
#################################
class ShiftConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, stride, width):
if bias is None:
bias = torch.empty(0)
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.width = width
output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
return output
@staticmethod
def backward(ctx, dy):
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
width = ctx.width
dx = dw = dbias = None
if ctx.needs_input_grad[0]:
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
if ctx.needs_input_grad[1]:
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
if ctx.needs_input_grad[2]:
dbias = torch.sum(dy, (1, 2, 3))
return dx, dw, dbias, None, None
#################################
######### BatchNorm ###########
#################################
class BatchNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, eps):
ctx.eps = eps
y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, eps)
ctx.save_for_backward(x, gamma, beta, mean, var)
return y
@staticmethod
def backward(ctx, dy):
eps = ctx.eps
x, gamma, beta, mean, var = ctx.saved_tensors
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
return dx, dg, db, None, None