```python
import torch
import triton

# Random NCHW input, Triton-layout weights (C_in, R, S, C_out), and a target
# for the L2 loss. With an 8x8 input, a 3x3 kernel, and no padding, the
# convolution output is 6x6.
x = torch.randn(16, 64, 8, 8, device='cuda', requires_grad=True)
w = torch.randn(64, 3, 3, 64, device='cuda', requires_grad=True)
# cuDNN expects OIHW weights, so keep a permuted copy as a separate leaf tensor.
cuw = w.detach().permute(3, 0, 1, 2).contiguous().requires_grad_()
y_target = torch.randn(16, 64, 6, 6, device='cuda')

def run(x, w, conv):
    # Forward + backward through `conv`; return the loss, output, and gradients.
    y = conv(x, w)
    loss = (y - y_target).norm(2)
    loss.backward()
    return loss, y.clone(), x.grad.clone(), w.grad.clone()

# Triton convolution (the trailing 0 is the padding).
ttloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))

# Reset accumulated gradients before the cuDNN reference run.
x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))

# L2 norms of the discrepancies: forward output, input gradient, and weight
# gradient (permuted back to OIHW for comparison). All three should be small.
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3, 0, 1, 2) - cudw).norm(2))
```
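Eyeballing printed norms works for a quick check, but the comparison can also fail loudly. A minimal sketch of that, assuming the variables above are in scope and using a hypothetical tolerance (Triton and cuDNN will not match bit-for-bit in floating point):

```python
# Hypothetical tolerance for fp32; the right bound depends on problem size
# and hardware, so tune it for your setup.
tol = 1e-4
assert torch.allclose(tty, cuy, atol=tol, rtol=tol), "forward outputs differ"
assert torch.allclose(ttdx, cudx, atol=tol, rtol=tol), "input gradients differ"
assert torch.allclose(ttdw.permute(3, 0, 1, 2), cudw, atol=tol, rtol=tol), "weight gradients differ"
```

`torch.allclose` combines absolute and relative tolerance, which is more robust than a raw norm when output and gradient magnitudes differ.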