[PYTHON][EXAMPLES][EINSUM] Added stride in CONV2D example
This commit is contained in:
committed by
Philippe Tillet
parent
f22ad0064c
commit
7924642b78
@@ -100,15 +100,18 @@ NCHWKRS = [
|
||||
#(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, G, H, W, K, R, S in NCHWKRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)
|
||||
stride = 2
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
|
||||
P = (H - R + 1) // stride
|
||||
Q = (W - S + 1) // stride
|
||||
transform_a = lambda a: a.view(N, G, C // G, H, W)
|
||||
transform_b = lambda b: b.view(C // G, R, S, G, K // G)
|
||||
transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1)
|
||||
transform_c = lambda c: c.view(N, K, P, Q)
|
||||
configs += [([N, C, H, W],
|
||||
[C // G, R, S, K],
|
||||
[N, G, K // G, H - R + 1, W - S + 1],
|
||||
[N, G, K // G, P, Q],
|
||||
torch_fn,
|
||||
'ngc(h+r)(w+s),crsgk->ngkhw',
|
||||
'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
|
||||
dict(), transform_a, transform_b, transform_c)]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user