diff --git a/python/examples/einsum.py b/python/examples/einsum.py index a99329c79..e044088e7 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -100,15 +100,18 @@ NCHWKRS = [ #(8, 512, 32, 32, 1024, 3, 3) ] for N, C, G, H, W, K, R, S in NCHWKRS: - torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G) + stride = 2 + torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G) + P = (H - R + 1) // stride + Q = (W - S + 1) // stride transform_a = lambda a: a.view(N, G, C // G, H, W) transform_b = lambda b: b.view(C // G, R, S, G, K // G) - transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1) + transform_c = lambda c: c.view(N, K, P, Q) configs += [([N, C, H, W], [C // G, R, S, K], - [N, G, K // G, H - R + 1, W - S + 1], + [N, G, K // G, P, Q], torch_fn, - 'ngc(h+r)(w+s),crsgk->ngkhw', + 'ngc(h*2+r)(w*2+s),crsgk->ngkhw', dict(), transform_a, transform_b, transform_c)]