[triton/dnn/conv] merged optimizations branch
- Added forward/backward support for strided convolution
- Added support for bias
- Added support for reduction splitting
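
For context, a minimal sketch of how the updated entry point is exercised, assuming only the call shape visible in the test diff below (triton.ConvFunction.apply(x, w, bias, stride, padding)); anything beyond that signature is an assumption:

    # Sketch based on the call sites in the diff below; the positional
    # (stride, padding) tuples mirror the (1,1), (1,1) arguments in the test.
    import torch
    import triton

    x = torch.autograd.Variable(torch.randn(64, 3, 8, 8).cuda(), requires_grad=True)
    w = torch.autograd.Variable(torch.randn(3, 3, 3, 6).cuda(), requires_grad=True)
    bias = torch.autograd.Variable(torch.randn(6).cuda(), requires_grad=True)
    y = triton.ConvFunction.apply(x, w, bias, (1, 1), (1, 1))  # stride, padding
    y.sum().backward()  # exercises the new backward path, including the bias gradient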
@@ -1,11 +1,14 @@
 import torch
 import triton
 
-x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
-bias = torch.autograd.Variable(torch.randn(64).cuda(), requires_grad=True)
-w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
+torch.manual_seed(0)
+torch.set_printoptions(precision=4)
+
+x = torch.autograd.Variable(torch.randn(64, 3, 8, 8).cuda(), requires_grad=True)
+bias = torch.autograd.Variable(torch.randn(6).cuda(), requires_grad=True)
+w = torch.autograd.Variable(torch.randn(3, 3, 3, 6).cuda(), requires_grad=True)
 cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
-y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
+y_target = torch.autograd.Variable(torch.randn(64, 6, 8, 8).cuda(), requires_grad=True)
 
 def run(x, w, conv):
     y = conv(x, w)
@@ -13,13 +16,14 @@ def run(x, w, conv):
     loss.backward()
     return loss, y.clone(), x.grad.clone(), w.grad.clone(), bias.grad.clone()
 
-ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (0,0)))
+ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (1,1)))
 x.grad.zero_()
 w.grad.zero_()
 bias.grad.zero_()
-culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=0))
+culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=1))
 
-print(ttdx[0,0,:,:], cudx[0,0,:,:])
+print(ttdx[0,0,:,:])
+print(cudx[0,0,:,:])
 print((tty - cuy).norm(2))
 print((ttdx - cudx).norm(2))
 print((ttdw.permute(3,0,1,2) - cudw).norm(2))
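
One note on the comparison at the end of the test: the Triton kernel keeps weights (and weight gradients) in a K-last layout, shape (3, 3, 3, 6) here, while torch.nn.functional.conv2d expects (out_channels, in_channels, kH, kW) = (6, 3, 3, 3); that is why both cuw and ttdw go through permute(3,0,1,2) before being handed to or compared against cuDNN. A small standalone sketch of that conversion (shapes are from the diff; the CRSK/KCRS naming is an assumption):

    import torch

    # Triton-side weight layout, assumed (C, R, S, K) = (3, 3, 3, 6)
    w = torch.randn(3, 3, 3, 6)
    # torch.nn.functional.conv2d expects (K, C, R, S) = (6, 3, 3, 3)
    cuw = w.permute(3, 0, 1, 2)
    assert cuw.shape == (6, 3, 3, 3)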