[PYTHON][EXAMPLES] Changed shape of einsum examples

This commit is contained in:
Philippe Tillet
2020-02-06 13:57:30 -05:00
committed by Philippe Tillet
parent 6d7cf35123
commit 558422c18a
2 changed files with 8 additions and 9 deletions

View File

@@ -13,7 +13,7 @@ configs = []
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
(8192, 8192, 8192),
#(8192, 8192, 8192),
# (64, 64, 64000),
# (64, 64, 128000),
@@ -68,7 +68,7 @@ NTHSE = [
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
(128, 1024, 8, 256, 512)
#(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution
NCHWKRS = [
(8, 64, 128, 128, 768, 3, 3),
#(8, 64, 128, 128, 768, 3, 3),
(8, 128, 64, 64, 256, 3, 3),
(8, 256, 32, 32, 512, 3, 3),
(8, 512, 32, 32, 1024, 3, 3)
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
# 3D Dense Convolution
NCDHWKTRS = [
(8, 32, 27, 100, 100, 64, 3, 3, 3),
(8, 64, 23, 48, 48, 256, 3, 3, 3),
(8, 256, 19, 22, 22, 640, 3, 3, 3),
(8, 640, 15, 36, 36, 384, 3, 3, 3)
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))

View File

@@ -646,5 +646,4 @@ __global__ void {name}(
return None, da, None, None, None
einsum = _einsum.apply