From 926acc2e2826abab058acdc1293b4b43ea797a52 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 25 Feb 2020 10:56:39 -0500 Subject: [PATCH] [TRITON][NN][CONV] Renamed input -> x to not modify built-in functions --- python/triton/nn/conv.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/python/triton/nn/conv.py b/python/triton/nn/conv.py index ff81f7f28..8a6c744d8 100644 --- a/python/triton/nn/conv.py +++ b/python/triton/nn/conv.py @@ -6,7 +6,7 @@ import torch.nn.functional as F class _conv2d(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, + def forward(ctx, x, weight, bias, stride, padding, dilation, groups, acc_bitmask): assert dilation == (1, 1) @@ -14,25 +14,25 @@ class _conv2d(torch.autograd.Function): assert bias == None pad_h, pad_w = padding stride_h, stride_w = stride - n, c, h, w = input.size() + n, c, h, w = x.size() k, c, r, s = weight.size() # allocate output p = (h + 2*padding[0] - r)//stride[0] + 1 q = (w + 2*padding[1] - s)//stride[1] + 1 - output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device) # padding if pad_h or pad_w: - input = triton.ops._einsum.pad(input, [pad_w, pad_w, pad_h, pad_h]) + x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h]) # convolution triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw', - input, weight, mask=acc_bitmask, + x, weight, mask=acc_bitmask, output=output, values = {'pad_h': pad_h, 'stride_h': stride_h, 'pad_w': pad_w, 'stride_w': stride_w}) # prepare backprop - ctx.save_for_backward(input, weight) + ctx.save_for_backward(x, weight) ctx.stride = stride ctx.padding = padding ctx.acc_bitmask = acc_bitmask @@ -42,7 +42,7 @@ class _conv2d(torch.autograd.Function): @staticmethod def backward(ctx, dy): # retrieve contextual information - input, weight = ctx.saved_tensors + x, weight = ctx.saved_tensors stride = ctx.stride padding = ctx.padding acc_bitmask = ctx.acc_bitmask @@ -51,13 +51,13 @@ class _conv2d(torch.autograd.Function): if ctx.needs_input_grad[0]: # dy must be padded n, k, p, q = dy.size() - n, c, h, w = input.size() + n, c, h, w = x.size() k, c, r, s = weight.size() dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4]) # have to be careful here # the gradient of strided conv is a conv over a sparse image # which can be decomposed as a set of smaller convs - dx = torch.empty_like(input) + dx = torch.empty_like(x) for offh in range(stride[0]): for offw in range(stride[1]): poffh = (offh + padding[0]) % stride[0] @@ -74,15 +74,13 @@ class _conv2d(torch.autograd.Function): mask = acc_bitmask, values = {'pad_h': pad_h, 'pad_w': pad_w}) - #if stride[0] == 2 and r == 3: - # print('dx: ', dx[0,0,0,0]) # gradient for the weight dw = None if ctx.needs_input_grad[1]: dw = torch.empty_like(weight) triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs', - input, dy, output = dw, mask = acc_bitmask) + x, dy, output = dw, mask = acc_bitmask) #print('dw: ', dw.view(-1)[0]) return dx, dw, None, None, None, None, None, None conv2d = _conv2d.apply @@ -100,7 +98,7 @@ class Conv2d(nn.Conv2d): def forward(self, input): #if self.kernel_size[0] == 3 and self.stride[0] != 1: - # print(self.padding, self.stride, input.size(), self.weight.size()) + #print(self.padding, self.stride, input.size(), self.weight.size()) # return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, @@ -127,13 +125,14 @@ def replace_conv2d(model, acc_bitmask = None): #torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3]) if __name__ == '__main__': - #N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3 + N, C, H, W, K, RS = 128, 64, 30, 30, 128, 1 #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3 - N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3 - pad, stride = 1, 2 + #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3 + pad, stride = 0, 1 torch.manual_seed(0) x = torch.randn((N, C, H, W)).cuda() x.requires_grad_(True) + #x.data[:] = 1 # initialize layers torch.manual_seed(0) rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda() @@ -156,9 +155,10 @@ if __name__ == '__main__': tdw = tconv2d.weight.grad.clone() x.grad.zero_() # print error - print((ry - ty).abs().max()) - print((rdx - tdx).abs().max()) - print((rdw - tdw).abs().max()) + diff = lambda x, y: (x - y).abs().max() + print(diff(ry, ty)) + print(diff(rdx, tdx)) + print(diff(rdw, tdw)) #print((rdx - tdx).abs()) #print((rdx[0,0,:,:] - tdx[0,0,:,:]))