[triton/dnn/conv] merged optimizations branch

- Added forward/backward support for strided convolution
- Added support for bias
- Added support for reduction splitting
This commit is contained in:
Philippe Tillet
2019-05-28 14:02:27 -04:00
parent e526ffc62b
commit a9d078c06f
47 changed files with 732 additions and 31339 deletions

View File

@@ -2,7 +2,7 @@ import torch
from torch.nn.modules.utils import _single, _pair, _triple
import math
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
class ConvFunction(torch.autograd.Function):
@@ -37,7 +37,7 @@ class _ConvNd(torch.nn.Module):
padding, dilation, transposed, output_padding, groups, bias):
super(_ConvNd, self).__init__()
# not everything is supported by Triton
assert all(x==1 for x in stride)
assert all(x==1 or x==2 for x in stride)
assert all(x==1 for x in dilation)
assert transposed == False
assert all(x==0 for x in output_padding)
@@ -46,6 +46,7 @@ class _ConvNd(torch.nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.weight = torch.nn.Parameter(torch.Tensor(
in_channels, kernel_size[0], kernel_size[1], out_channels))
@@ -56,7 +57,7 @@ class _ConvNd(torch.nn.Module):
self.reset_parameters()
def forward(self, input):
return ConvFunction.apply(input, self.weight, self.bias, self.padding)
return ConvFunction.apply(input, self.weight, self.bias, self.stride, self.padding)
def reset_parameters(self):
n = self.in_channels