[PYTHON][EXAMPLES][EINSUM] Updated configs for matmul

This commit is contained in:
Philippe Tillet
2020-04-10 12:42:48 -04:00
committed by Philippe Tillet
parent 7924642b78
commit c36ad6bf8a

View File

@@ -33,15 +33,15 @@ MNK = [
# (127008, 768, 576)
]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b)
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a.t(), b)
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b.t())
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
# Relative attention
NTHSE = [
@@ -73,11 +73,11 @@ NTHSE = [
#(128, 1024, 8, 256, 512)
]
#for N, T, H, S, E in NTHSE:
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
# 1D Dense convolution
NCHKR = [
@@ -90,13 +90,13 @@ for N, C, H, K, R in NCHKR:
[N, K, H - R + 1],
torch_fn,
'nc(h+r),crk->nkh',
dict())]
dict(), None, None, None)]
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(128, 3, 32, 32, 64, 3, 3),
(1, 1024, 32, 112, 112, 1024, 3, 3),
#(1, 1024, 32, 112, 112, 1024, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, G, H, W, K, R, S in NCHWKRS:
@@ -129,7 +129,7 @@ for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
[N, K, D - T + 1, H - R + 1, W - R + 1],
torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict())]
dict(), None, None, None)]
# Shift convolution
@@ -172,7 +172,8 @@ for N, C, H, W, K, R, S in NCHWKRS:
[N, K, H, W],
shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w})]
{'sh': shift_h, 'sw': shift_w},
None, None, None)]
# Benchmark
torch.set_num_threads(1)