[triton/python/conv]: Added cache for compiled kernels
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import time
|
||||
torch.manual_seed(0)
|
||||
|
||||
class TritonConv(torch.autograd.Function):
|
||||
@@ -14,9 +15,9 @@ class TritonConv(torch.autograd.Function):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
|
||||
return grad_input, grad_weight
|
||||
|
||||
|
||||
@@ -38,6 +39,7 @@ x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
|
||||
|
||||
|
||||
print((tty - cuy).norm(2))
|
||||
print((ttdx - cudx).norm(2))
|
||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
||||
|
Reference in New Issue
Block a user