[PYTHON][NN][CONV] Fixed typo in dx computation
committed by Philippe Tillet
parent 01154f24db
commit 420e36a038
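The fix swaps `offh`/`offw` and `poffh`/`poffw` in the input-gradient path: slices of `dx` are indexed by the raw output offsets `offh`/`offw`, while the padding-shifted offsets `poffh = (offh + padding[0]) % stride[0]` (and likewise `poffw`) select which kernel taps contribute, so the "no valid tap" test must read `poffh >= r or poffw >= s`. A minimal sketch of that indexing rule in plain PyTorch (no Triton needed; the `stride`/`padding`/kernel values and shapes below are illustrative, not from this commit):

# Sketch: why the emptiness test uses poffh while dx is sliced by offh.
# Input rows with h % stride == offh receive gradient only from kernel taps u
# with u % stride == poffh, where poffh = (offh + padding) % stride.
# If poffh >= r, no such tap exists and that whole dx slice must be zero.
import torch
import torch.nn.functional as F

stride, padding, r = 2, 1, 1    # illustrative values, chosen so poffh can exceed r-1
x = torch.randn(1, 1, 8, 8, requires_grad=True)
w = torch.randn(1, 1, r, r)
F.conv2d(x, w, stride=stride, padding=padding).sum().backward()

for offh in range(stride):
    poffh = (offh + padding) % stride
    empty = poffh >= r                                             # the fixed condition
    is_zero = bool((x.grad[:, :, offh::stride, :] == 0).all())     # the fixed slice
    print(f'offh={offh} poffh={poffh} expected_zero={empty} actually_zero={is_zero}')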
@@ -64,22 +64,26 @@ class _conv2d(torch.autograd.Function):
                 poffw = (offw + padding[1]) % stride[1]
                 pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
                 pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
-                if offh >= r or offw >= s:
-                    dx[:, :, poffh::stride[0], poffw::stride[1]] = 0
+                if poffh >= r or poffw >= s:
+                    dx[:, :, offh::stride[0], offw::stride[1]] = 0
                 else:
                     triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
                                       dypad[:, :, :, :],
-                                      weight[:, :, offh::stride[0], offw::stride[1]],
-                                      output = dx[:, :, poffh::stride[0], poffw::stride[1]],
+                                      weight[:, :, poffh::stride[0], poffw::stride[1]],
+                                      output = dx[:, :, offh::stride[0], offw::stride[1]],
                                       mask = acc_bitmask,
                                       values = {'pad_h': pad_h,
                                                 'pad_w': pad_w})
+                #if stride[0] == 2 and r == 3:
+                #  print('dx: ', dx[0,0,0,0])
+
         # gradient for the weight
         dw = None
         if ctx.needs_input_grad[1]:
             dw = torch.empty_like(weight)
             triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
                               input, dy, output = dw, mask = acc_bitmask)
+            #print('dw: ', dw.view(-1)[0])
         return dx, dw, None, None, None, None, None, None
 conv2d = _conv2d.apply
 
@@ -95,7 +99,8 @@ class Conv2d(nn.Conv2d):
         self.acc_bitmask = acc_bitmask
 
     def forward(self, input):
-        #if self.kernel_size[0] == 3:
+        #if self.kernel_size[0] == 3 and self.stride[0] != 1:
+        #  print(self.padding, self.stride, input.size(), self.weight.size())
         # return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return conv2d(input, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups,
@@ -114,3 +119,48 @@ def replace_conv2d(model, acc_bitmask = None):
             setattr(model, child_name, conv2d)
         else:
             replace_conv2d(child, acc_bitmask)
+
+# initialize input
+#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
+#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
+#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
+#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
+
+if __name__ == '__main__':
+    #N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
+    #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
+    N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
+    pad, stride = 1, 2
+    torch.manual_seed(0)
+    x = torch.randn((N, C, H, W)).cuda()
+    x.requires_grad_(True)
+    # initialize layers
+    torch.manual_seed(0)
+    rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
+    torch.manual_seed(0)
+    tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
+    #rconv2d.weight.data[:] = 1
+    #tconv2d.weight.data[:] = 1
+    ry = rconv2d(x)
+    ty = tconv2d(x)
+    # reference
+    dy = torch.randn(ry.size()).cuda()
+    #dy.data[:] = 1
+    ry.backward(dy)
+    rdx = x.grad.clone()
+    rdw = rconv2d.weight.grad.clone()
+    x.grad.zero_()
+    # triton
+    ty.backward(dy)
+    tdx = x.grad.clone()
+    tdw = tconv2d.weight.grad.clone()
+    x.grad.zero_()
+    # print error
+    print((ry - ty).abs().max())
+    print((rdx - tdx).abs().max())
+    print((rdw - tdw).abs().max())
+    #print((rdx - tdx).abs())
+
+    #print((rdx[0,0,:,:] - tdx[0,0,:,:]))
+    #print(rdx[0,0,:,:])
+    #print(tdx[0,0,:,:])