[TRITON][NN][CONV] Renamed input -> x to not modify built-in functions
Committed by Philippe Tillet
parent 420e36a038 · commit 926acc2e28
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 class _conv2d(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, weight, bias,
+    def forward(ctx, x, weight, bias,
                 stride, padding, dilation, groups,
                 acc_bitmask):
         assert dilation == (1, 1)
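Note on the rename itself: `input` is a Python built-in function, so using it as a parameter name shadows the built-in inside `forward` (and trips linters), even though it never modifies the global built-in. A minimal illustration of the shadowing this rename avoids (the function and name below are illustrative, not from this file):

def scale(input):
    # Inside this function, `input` is the tensor argument; the built-in
    # input() that reads from stdin is shadowed and unreachable by that name.
    return input * 2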
@@ -14,25 +14,25 @@ class _conv2d(torch.autograd.Function):
         assert bias == None
         pad_h, pad_w = padding
         stride_h, stride_w = stride
-        n, c, h, w = input.size()
+        n, c, h, w = x.size()
         k, c, r, s = weight.size()
         # allocate output
         p = (h + 2*padding[0] - r)//stride[0] + 1
         q = (w + 2*padding[1] - s)//stride[1] + 1
-        output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device)
+        output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device)
         # padding
         if pad_h or pad_w:
-            input = triton.ops._einsum.pad(input, [pad_w, pad_w, pad_h, pad_h])
+            x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h])
         # convolution
         triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
-                          input, weight, mask=acc_bitmask,
+                          x, weight, mask=acc_bitmask,
                           output=output,
                           values = {'pad_h': pad_h,
                                     'stride_h': stride_h,
                                     'pad_w': pad_w,
                                     'stride_w': stride_w})
         # prepare backprop
-        ctx.save_for_backward(input, weight)
+        ctx.save_for_backward(x, weight)
         ctx.stride = stride
         ctx.padding = padding
         ctx.acc_bitmask = acc_bitmask
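For readers unfamiliar with the extended einsum notation above: the index expressions `(h*stride_h + r - pad_h)` and `(w*stride_w + s - pad_w)` encode the sliding-window gather, so the contraction over c, r, s is a direct convolution. A rough PyTorch-only sketch of the same shape arithmetic and result, useful for sanity checks (assumes dilation=1, groups=1, no bias; `reference_conv2d` is a hypothetical helper, not part of this commit):

import torch
import torch.nn.functional as F

def reference_conv2d(x, weight, stride, padding):
    n, c, h, w = x.size()
    k, c, r, s = weight.size()
    # same output-shape arithmetic as the Triton path above
    p = (h + 2*padding[0] - r) // stride[0] + 1
    q = (w + 2*padding[1] - s) // stride[1] + 1
    out = F.conv2d(x, weight, stride=stride, padding=padding)
    assert out.shape == (n, k, p, q)
    return out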
@@ -42,7 +42,7 @@ class _conv2d(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dy):
         # retrieve contextual information
-        input, weight = ctx.saved_tensors
+        x, weight = ctx.saved_tensors
         stride = ctx.stride
         padding = ctx.padding
         acc_bitmask = ctx.acc_bitmask
@@ -51,13 +51,13 @@ class _conv2d(torch.autograd.Function):
         if ctx.needs_input_grad[0]:
             # dy must be padded
             n, k, p, q = dy.size()
-            n, c, h, w = input.size()
+            n, c, h, w = x.size()
             k, c, r, s = weight.size()
             dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
             # have to be careful here
             # the gradient of strided conv is a conv over a sparse image
             # which can be decomposed as a set of smaller convs
-            dx = torch.empty_like(input)
+            dx = torch.empty_like(x)
             for offh in range(stride[0]):
                 for offw in range(stride[1]):
                     poffh = (offh + padding[0]) % stride[0]
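The comment about the gradient of a strided conv refers to the standard identity that dx is a transposed convolution of dy with the same weights; the per-offset loop realizes that as stride_h * stride_w smaller dense convs instead of one conv over a zero-dilated dy. A small PyTorch check of the underlying identity, a sketch under the hunk's assumptions (dilation=1, groups=1; sizes chosen so no output_padding is needed):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 9, 9, requires_grad=True)
w = torch.randn(6, 4, 3, 3)
stride, padding = (2, 2), (1, 1)

y = F.conv2d(x, w, stride=stride, padding=padding)   # y has shape 2x6x5x5
dy = torch.randn_like(y)
y.backward(dy)

# dx equals the transposed convolution of dy with the same weights
dx_ref = F.conv_transpose2d(dy, w, stride=stride, padding=padding)
print(torch.allclose(x.grad, dx_ref, atol=1e-4))      # expected: True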
@@ -74,15 +74,13 @@ class _conv2d(torch.autograd.Function):
                                       mask = acc_bitmask,
                                       values = {'pad_h': pad_h,
                                                 'pad_w': pad_w})
-        #if stride[0] == 2 and r == 3:
-        # print('dx: ', dx[0,0,0,0])

         # gradient for the weight
         dw = None
         if ctx.needs_input_grad[1]:
             dw = torch.empty_like(weight)
             triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
-                              input, dy, output = dw, mask = acc_bitmask)
+                              x, dy, output = dw, mask = acc_bitmask)
             #print('dw: ', dw.view(-1)[0])
         return dx, dw, None, None, None, None, None, None
 conv2d = _conv2d.apply
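Similarly, the weight-gradient einsum contracts x with dy over the batch and output positions. PyTorch exposes the same quantity through `torch.nn.grad.conv2d_weight`, which can serve as a reference when validating this path (again a sketch, assuming dilation=1 and groups=1):

import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_weight

torch.manual_seed(0)
x = torch.randn(2, 4, 9, 9)
w = torch.randn(6, 4, 3, 3, requires_grad=True)
stride, padding = (2, 2), (1, 1)

y = F.conv2d(x, w, stride=stride, padding=padding)
dy = torch.randn_like(y)
y.backward(dy)

# dw equals the correlation of x with dy over batch and output positions
dw_ref = conv2d_weight(x, w.shape, dy, stride=stride, padding=padding)
print(torch.allclose(w.grad, dw_ref, atol=1e-3))      # expected: True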
@@ -100,7 +98,7 @@ class Conv2d(nn.Conv2d):

     def forward(self, input):
         #if self.kernel_size[0] == 3 and self.stride[0] != 1:
-        # print(self.padding, self.stride, input.size(), self.weight.size())
+        #print(self.padding, self.stride, input.size(), self.weight.size())
         # return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return conv2d(input, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups,
@@ -127,13 +125,14 @@ def replace_conv2d(model, acc_bitmask = None):
 #torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])

 if __name__ == '__main__':
-    #N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
+    N, C, H, W, K, RS = 128, 64, 30, 30, 128, 1
     #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
-    N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
-    pad, stride = 1, 2
+    #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
+    pad, stride = 0, 1
     torch.manual_seed(0)
     x = torch.randn((N, C, H, W)).cuda()
     x.requires_grad_(True)
+    #x.data[:] = 1
     # initialize layers
     torch.manual_seed(0)
     rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
@@ -156,9 +155,10 @@ if __name__ == '__main__':
     tdw = tconv2d.weight.grad.clone()
     x.grad.zero_()
     # print error
-    print((ry - ty).abs().max())
-    print((rdx - tdx).abs().max())
-    print((rdw - tdw).abs().max())
+    diff = lambda x, y: (x - y).abs().max()
+    print(diff(ry, ty))
+    print(diff(rdx, tdx))
+    print(diff(rdw, tdw))
     #print((rdx - tdx).abs())

     #print((rdx[0,0,:,:] - tdx[0,0,:,:]))
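The new `diff` helper only reports the maximum absolute error between the reference and Triton results (its parameter `x` briefly shadows the test tensor `x`, which is harmless here). If a pass/fail signal is preferred, a hypothetical variant of the same check could assert a tolerance instead:

    # hypothetical variant: fail loudly instead of printing the max error
    assert torch.allclose(ry, ty, atol=1e-3, rtol=1e-3)
    assert torch.allclose(rdx, tdx, atol=1e-3, rtol=1e-3)
    assert torch.allclose(rdw, tdw, atol=1e-3, rtol=1e-3)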