[triton/python/conv]: Added cache for compiled kernels

This commit is contained in:
Philippe Tillet
2019-05-18 11:51:49 -04:00
parent 600aef72d5
commit b2b55c52c9
10 changed files with 210 additions and 516 deletions

View File

@@ -1,4 +1,5 @@
import torch
import time
torch.manual_seed(0)
class TritonConv(torch.autograd.Function):
@@ -14,9 +15,9 @@ class TritonConv(torch.autograd.Function):
input, weight = ctx.saved_tensors
grad_input = grad_weight = None
if ctx.needs_input_grad[0]:
-grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
+grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
if ctx.needs_input_grad[1]:
-grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
+grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
return grad_input, grad_weight
@@ -38,6 +39,7 @@ x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))