[PYTHON][EXAMPLES][EINSUM] Updated configs for matmul
This commit is contained in:
committed by
Philippe Tillet
parent
7924642b78
commit
c36ad6bf8a
@@ -33,15 +33,15 @@ MNK = [
|
|||||||
|
|
||||||
# (127008, 768, 576)
|
# (127008, 768, 576)
|
||||||
]
|
]
|
||||||
#for M, N, K in MNK:
|
for M, N, K in MNK:
|
||||||
# matmul = lambda a, b: torch.matmul(a, b)
|
matmul = lambda a, b: torch.matmul(a, b)
|
||||||
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
|
||||||
#for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||||
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
|
||||||
#for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
# matmul = lambda a, b: torch.matmul(a, b.t())
|
# matmul = lambda a, b: torch.matmul(a, b.t())
|
||||||
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
|
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
|
||||||
|
|
||||||
# Relative attention
|
# Relative attention
|
||||||
NTHSE = [
|
NTHSE = [
|
||||||
@@ -73,11 +73,11 @@ NTHSE = [
|
|||||||
#(128, 1024, 8, 256, 512)
|
#(128, 1024, 8, 256, 512)
|
||||||
]
|
]
|
||||||
#for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
|
||||||
#for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
|
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
|
||||||
#for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
|
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
|
||||||
|
|
||||||
# 1D Dense convolution
|
# 1D Dense convolution
|
||||||
NCHKR = [
|
NCHKR = [
|
||||||
@@ -90,13 +90,13 @@ for N, C, H, K, R in NCHKR:
|
|||||||
[N, K, H - R + 1],
|
[N, K, H - R + 1],
|
||||||
torch_fn,
|
torch_fn,
|
||||||
'nc(h+r),crk->nkh',
|
'nc(h+r),crk->nkh',
|
||||||
dict())]
|
dict(), None, None, None)]
|
||||||
|
|
||||||
# 2D Dense convolution
|
# 2D Dense convolution
|
||||||
NCHWKRS = [
|
NCHWKRS = [
|
||||||
#(8, 64, 128, 128, 768, 3, 3),
|
#(8, 64, 128, 128, 768, 3, 3),
|
||||||
#(128, 3, 32, 32, 64, 3, 3),
|
#(128, 3, 32, 32, 64, 3, 3),
|
||||||
(1, 1024, 32, 112, 112, 1024, 3, 3),
|
#(1, 1024, 32, 112, 112, 1024, 3, 3),
|
||||||
#(8, 512, 32, 32, 1024, 3, 3)
|
#(8, 512, 32, 32, 1024, 3, 3)
|
||||||
]
|
]
|
||||||
for N, C, G, H, W, K, R, S in NCHWKRS:
|
for N, C, G, H, W, K, R, S in NCHWKRS:
|
||||||
@@ -129,7 +129,7 @@ for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
|||||||
[N, K, D - T + 1, H - R + 1, W - R + 1],
|
[N, K, D - T + 1, H - R + 1, W - R + 1],
|
||||||
torch_fn,
|
torch_fn,
|
||||||
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
|
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
|
||||||
dict())]
|
dict(), None, None, None)]
|
||||||
|
|
||||||
|
|
||||||
# Shift convolution
|
# Shift convolution
|
||||||
@@ -172,7 +172,8 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
|||||||
[N, K, H, W],
|
[N, K, H, W],
|
||||||
shift_conv,
|
shift_conv,
|
||||||
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
|
||||||
{'sh': shift_h, 'sw': shift_w})]
|
{'sh': shift_h, 'sw': shift_w},
|
||||||
|
None, None, None)]
|
||||||
|
|
||||||
# Benchmark
|
# Benchmark
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
Reference in New Issue
Block a user