diff --git a/python/triton/nn/conv.py b/python/triton/nn/conv.py
index f0bcf589c..ff81f7f28 100644
--- a/python/triton/nn/conv.py
+++ b/python/triton/nn/conv.py
@@ -64,22 +64,26 @@ class _conv2d(torch.autograd.Function):
                     poffw = (offw + padding[1]) % stride[1]
                     pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
                     pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
-                    if offh >= r or offw >= s:
-                        dx[:, :, poffh::stride[0], poffw::stride[1]] = 0
+                    if poffh >= r or poffw >= s:
+                        dx[:, :, offh::stride[0], offw::stride[1]] = 0
                     else:
                         triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
                                           dypad[:, :, :, :],
-                                          weight[:, :, offh::stride[0], offw::stride[1]],
-                                          output = dx[:, :, poffh::stride[0], poffw::stride[1]],
+                                          weight[:, :, poffh::stride[0], poffw::stride[1]],
+                                          output = dx[:, :, offh::stride[0], offw::stride[1]],
                                           mask = acc_bitmask,
                                           values = {'pad_h': pad_h,
                                                     'pad_w': pad_w})
+        #if stride[0] == 2 and r == 3:
+        #    print('dx: ', dx[0,0,0,0])
+
         # gradient for the weight
         dw = None
         if ctx.needs_input_grad[1]:
             dw = torch.empty_like(weight)
             triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
                               input, dy, output = dw, mask = acc_bitmask)
+            #print('dw: ', dw.view(-1)[0])
         return dx, dw, None, None, None, None, None, None
 conv2d = _conv2d.apply
@@ -95,7 +99,8 @@ class Conv2d(nn.Conv2d):
         self.acc_bitmask = acc_bitmask
 
     def forward(self, input):
-        #if self.kernel_size[0] == 3:
+        #if self.kernel_size[0] == 3 and self.stride[0] != 1:
+        #    print(self.padding, self.stride, input.size(), self.weight.size())
         #    return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups,
                       self.acc_bitmask)
@@ -113,4 +118,49 @@ def replace_conv2d(model, acc_bitmask = None):
                 yparam.data.copy_(xparam.data)
             setattr(model, child_name, conv2d)
         else:
-            replace_conv2d(child, acc_bitmask)
\ No newline at end of file
+            replace_conv2d(child, acc_bitmask)
+
+# problem sizes (input, weight) recorded from the debug print in Conv2d.forward:
+#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
+#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
+#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
+#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
+
+if __name__ == '__main__':
+    #N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
+    #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
+    N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
+    pad, stride = 1, 2
+    torch.manual_seed(0)
+    x = torch.randn((N, C, H, W)).cuda()
+    x.requires_grad_(True)
+    # initialize reference and triton layers with identical weights
+    torch.manual_seed(0)
+    rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
+    torch.manual_seed(0)
+    tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
+    #rconv2d.weight.data[:] = 1
+    #tconv2d.weight.data[:] = 1
+    ry = rconv2d(x)
+    ty = tconv2d(x)
+    # reference gradients from torch.nn.Conv2d
+    dy = torch.randn(ry.size()).cuda()
+    #dy.data[:] = 1
+    ry.backward(dy)
+    rdx = x.grad.clone()
+    rdw = rconv2d.weight.grad.clone()
+    x.grad.zero_()
+    # triton gradients
+    ty.backward(dy)
+    tdx = x.grad.clone()
+    tdw = tconv2d.weight.grad.clone()
+    x.grad.zero_()
+    # print max absolute errors (output, input grad, weight grad)
+    print((ry - ty).abs().max())
+    print((rdx - tdx).abs().max())
+    print((rdw - tdw).abs().max())
+    #print((rdx - tdx).abs())
+
+    #print((rdx[0,0,:,:] - tdx[0,0,:,:]))
+    #print(rdx[0,0,:,:])
+    #print(tdx[0,0,:,:])
\ No newline at end of file
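
Commentary on the dx fix (not part of the patch): for a convolution with stride s and padding p, dx row h only receives contributions from kernel rows i with i ≡ (h + p) (mod s). The old code subsampled the weight starting at offh and wrote the result to the dx sub-grid starting at poffh; the correct convention is the opposite, i.e. weight rows poffh = (offh + p) % s feed dx rows offh, offh + s, .... The sketch below checks this numerically with plain PyTorch autograd and no triton dependency; the helper dx_of and all sizes are illustrative, not taken from conv.py.

# Minimal standalone check (plain PyTorch, no triton); the helper `dx_of`
# and all sizes here are illustrative, not taken from conv.py.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
stride, pad, R = 2, 1, 3
x = torch.randn(1, 1, 8, 8, requires_grad=True)
w = torch.randn(1, 1, R, R)

def dx_of(weight):
    # input gradient of a strided conv for dy = ones, via autograd
    y = F.conv2d(x, weight, stride=stride, padding=pad)
    return torch.autograd.grad(y, x, torch.ones_like(y))[0]

base = dx_of(w)
for offh in range(min(R, stride)):
    poffh = (offh + pad) % stride
    w2 = w.clone()
    w2[:, :, poffh::stride, :] += 1.0   # perturb kernel rows poffh::stride
    diff = (dx_of(w2) - base).abs().sum(dim=(0, 1, 3))
    rows = (diff > 0).nonzero().flatten().tolist()
    # only dx rows offh, offh+stride, ... should react
    print(f'offh={offh} poffh={poffh} -> dx rows changed: {rows}')

With stride 2 and padding 1 this prints that perturbing kernel rows 1::2 changes dx rows 0::2 and vice versa, i.e. offh and poffh really do cross over, which is exactly the case the old indexing got wrong.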
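
The self-test appended at the bottom of conv.py can presumably be run directly (python python/triton/nn/conv.py, on a CUDA-capable machine since all tensors are moved to the GPU): it seeds both layers identically and prints the maximum absolute error of the forward output, the input gradient, and the weight gradient against torch.nn.Conv2d. After this fix, all three should be near zero (up to floating-point accumulation error) for the strided 3x3 cases listed in the comments.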