Files
triton/examples/python/pytorch/triton.py
2019-05-19 01:31:08 -04:00

47 lines
1.6 KiB
Python

import torch
import math
# Load the compiled Triton extension; the rest of this file calls its operators
# through the torch.ops.triton namespace.
# NOTE(review): the path is an absolute location inside one developer's build
# tree, so importing this module will presumably fail on any other machine --
# confirm and consider making it configurable.
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
class ConvFunction(torch.autograd.Function):
    """Autograd wrapper around the ``torch.ops.triton`` convolution operators.

    ``forward`` calls ``conv_fprop``; ``backward`` calls ``conv_bprop`` and
    ``conv_wgrad`` to produce the gradients of the input and the weight.
    """

    @staticmethod
    def forward(ctx, input, weight, padding):
        """Run the forward convolution and stash what ``backward`` needs.

        ``padding`` is handed to the native operator twice.
        NOTE(review): presumably once per spatial dimension -- confirm
        against the operator's signature.
        """
        # padding is not a tensor, so it is kept as a plain attribute on ctx.
        ctx.padding = padding
        ctx.save_for_backward(input, weight)
        return torch.ops.triton.conv_fprop(input, weight, padding, padding)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, padding)``.

        A gradient is only computed when autograd reports that it is needed;
        otherwise ``None`` is returned in its slot. The ``padding`` slot is
        always ``None``.
        """
        saved_input, saved_weight = ctx.saved_tensors
        pad = ctx.padding

        d_input = (
            torch.ops.triton.conv_bprop(grad_output, saved_weight, pad, pad)
            if ctx.needs_input_grad[0]
            else None
        )
        d_weight = (
            torch.ops.triton.conv_wgrad(saved_input, grad_output, pad, pad)
            if ctx.needs_input_grad[1]
            else None
        )
        return d_input, d_weight, None
class Conv2d(torch.nn.Module):
    """2-D convolution layer whose forward pass delegates to ``ConvFunction``.

    The learnable ``weight`` parameter is allocated with shape
    ``(in_channels, kernel_h, kernel_w, out_channels)``.

    Args:
        in_channels: Number of input channels; first dimension of ``weight``.
        out_channels: Number of output channels; last dimension of ``weight``.
        kernel_size: Either a ``(kernel_h, kernel_w)`` sequence or a single
            int, which is expanded to a square ``(k, k)`` kernel.
        padding: Passed unchanged to ``ConvFunction``.
            NOTE(review): ``ConvFunction`` hands this value to the native
            operators twice, so it is presumably a single int applied to
            both spatial dimensions -- confirm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
        super(Conv2d, self).__init__()
        # Accept a bare int for square kernels; previously this raised a
        # TypeError when the value was indexed below.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        # torch.empty is the supported way to allocate uninitialised storage;
        # the values are filled in by reset_parameters() right after.
        self.weight = torch.nn.Parameter(torch.empty(
            in_channels, kernel_size[0], kernel_size[1], out_channels))
        self.reset_parameters()

    def forward(self, input):
        """Apply the convolution to ``input`` using the layer's weight and padding."""
        return ConvFunction.apply(input, self.weight, self.padding)

    def reset_parameters(self):
        """Re-initialise ``weight`` uniformly in ``[-1/sqrt(n), 1/sqrt(n)]``.

        ``n`` is ``in_channels`` multiplied by every entry of ``kernel_size``.
        """
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        # Initialise in place without recording the operation in autograd.
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)