[PYTHON][EXAMPLES] Changed shape of einsum examples
This commit is contained in:
committed by
Philippe Tillet
parent
6d7cf35123
commit
558422c18a
@@ -13,7 +13,7 @@ configs = []
|
||||
MNK = [
|
||||
(512, 512 ,512),
|
||||
(2048, 2048, 2048),
|
||||
(8192, 8192, 8192),
|
||||
#(8192, 8192, 8192),
|
||||
|
||||
# (64, 64, 64000),
|
||||
# (64, 64, 128000),
|
||||
@@ -68,7 +68,7 @@ NTHSE = [
|
||||
# (128, 1024, 8, 64, 64),
|
||||
# (128, 1024, 8, 128, 128),
|
||||
# (128, 1024, 8, 256, 256),
|
||||
(128, 1024, 8, 256, 512)
|
||||
#(128, 1024, 8, 256, 512)
|
||||
]
|
||||
for N, T, H, S, E in NTHSE:
|
||||
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
||||
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
|
||||
|
||||
# 2D Dense convolution
|
||||
NCHWKRS = [
|
||||
(8, 64, 128, 128, 768, 3, 3),
|
||||
#(8, 64, 128, 128, 768, 3, 3),
|
||||
(8, 128, 64, 64, 256, 3, 3),
|
||||
(8, 256, 32, 32, 512, 3, 3),
|
||||
(8, 512, 32, 32, 1024, 3, 3)
|
||||
#(8, 512, 32, 32, 1024, 3, 3)
|
||||
]
|
||||
for N, C, H, W, K, R, S in NCHWKRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
||||
|
||||
# 3D Dense Convolution
|
||||
NCDHWKTRS = [
|
||||
(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
||||
(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
||||
(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
||||
(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
||||
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
||||
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
||||
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
||||
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
||||
]
|
||||
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
||||
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
|
||||
|
@@ -646,5 +646,4 @@ __global__ void {name}(
|
||||
return None, da, None, None, None
|
||||
|
||||
|
||||
|
||||
einsum = _einsum.apply
|
Reference in New Issue
Block a user