[PYTHON][EXAMPLES][EINSUM] Added stride in CONV2D example

This commit is contained in:
Philippe Tillet
2020-04-10 00:14:31 -04:00
committed by Philippe Tillet
parent f22ad0064c
commit 7924642b78

View File

@@ -100,15 +100,18 @@ NCHWKRS = [
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, G, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)
stride = 2
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
P = (H - R + 1) // stride
Q = (W - S + 1) // stride
transform_a = lambda a: a.view(N, G, C // G, H, W)
transform_b = lambda b: b.view(C // G, R, S, G, K // G)
transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1)
transform_c = lambda c: c.view(N, K, P, Q)
configs += [([N, C, H, W],
[C // G, R, S, K],
[N, G, K // G, H - R + 1, W - S + 1],
[N, G, K // G, P, Q],
torch_fn,
'ngc(h+r)(w+s),crsgk->ngkhw',
'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
dict(), transform_a, transform_b, transform_c)]