[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm
@@ -4,6 +4,10 @@ import math
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")

#################################
####### Convolutions ##########
#################################

class ConvFunction(torch.autograd.Function):

    @staticmethod
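# Aside: the absolute path above only resolves on the original build machine.
# A minimal, more portable sketch; the TRITON_TORCH_LIB variable and the
# fallback path are illustrative assumptions, not part of this commit.
import os

torch.ops.load_library(os.environ.get(
    "TRITON_TORCH_LIB",
    "/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so"))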
@@ -81,3 +85,54 @@ class Conv2d(_ConvNd):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

#################################
#### Shift-Convolutions #######
#################################

class ShiftConvFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, stride, width):
        # The custom op takes a Tensor, so represent a missing bias as an
        # empty tensor rather than None.
        if bias is None:
            bias = torch.empty(0)
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.width = width
        output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
        return output

    @staticmethod
    def backward(ctx, dy):
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        width = ctx.width
        dx = dw = dbias = None
        # Only compute the gradients autograd actually asks for.
        if ctx.needs_input_grad[0]:
            dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
        if ctx.needs_input_grad[1]:
            dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
        if ctx.needs_input_grad[2]:
            # Bias gradient: reduce over every dim except 0, i.e. channels-first layout.
            dbias = torch.sum(dy, (1, 2, 3))
        # One gradient per forward() input; stride and width are non-differentiable.
        return dx, dw, dbias, None, None
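
# A hedged usage sketch (not part of this commit): the tensor shapes, the
# channels-first layout, and the stride/width tuples below are assumptions
# that depend on how the Triton kernels are defined.
x = torch.randn(32, 16, 16, 8, device="cuda", requires_grad=True)
weight = torch.randn(32, 32, device="cuda", requires_grad=True)
bias = torch.randn(32, device="cuda", requires_grad=True)
y = ShiftConvFunction.apply(x, weight, bias, (1, 1), (3, 3))
y.sum().backward()  # exercises shift_conv_dx, shift_conv_dw and the dbias sum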

#################################
######### BatchNorm ###########
#################################

class BatchNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        ctx.eps = eps
        # The fused kernel returns the output together with the batch mean
        # and variance, which the backward kernel reuses.
        y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, eps)
        ctx.save_for_backward(x, gamma, beta, mean, var)
        return y

    @staticmethod
    def backward(ctx, dy):
        eps = ctx.eps
        x, gamma, beta, mean, var = ctx.saved_tensors
        dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
        # One gradient per forward() input (x, gamma, beta, eps); eps is
        # non-differentiable.
        return dx, dg, db, None
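
# A hedged usage sketch (not part of this commit); the shapes and the
# channels-first layout are assumptions, with gamma/beta as per-channel
# parameters.
x = torch.randn(32, 16, 16, 8, device="cuda", requires_grad=True)
gamma = torch.ones(32, device="cuda", requires_grad=True)
beta = torch.zeros(32, device="cuda", requires_grad=True)
y = BatchNormFunction.apply(x, gamma, beta, 1e-5)
y.mean().backward()  # dx, dgamma, dbeta come from batchnorm_dxdgdb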