[PYTHON][EXAMPLES][EINSUM] Updated configs for matmul

This commit is contained in:
Philippe Tillet
2020-04-10 12:42:48 -04:00
committed by Philippe Tillet
parent 7924642b78
commit c36ad6bf8a

View File

@@ -33,15 +33,15 @@ MNK = [
# (127008, 768, 576) # (127008, 768, 576)
] ]
#for M, N, K in MNK: for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b) matmul = lambda a, b: torch.matmul(a, b)
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
#for M, N, K in MNK: #for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a.t(), b) # matmul = lambda a, b: torch.matmul(a.t(), b)
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] # configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK: #for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b.t()) # matmul = lambda a, b: torch.matmul(a, b.t())
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())] # configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
# Relative attention # Relative attention
NTHSE = [ NTHSE = [
@@ -73,11 +73,11 @@ NTHSE = [
#(128, 1024, 8, 256, 512) #(128, 1024, 8, 256, 512)
] ]
#for N, T, H, S, E in NTHSE: #for N, T, H, S, E in NTHSE:
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] # configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE: #for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())] # configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE: #for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())] # configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
# 1D Dense convolution # 1D Dense convolution
NCHKR = [ NCHKR = [
@@ -90,13 +90,13 @@ for N, C, H, K, R in NCHKR:
[N, K, H - R + 1], [N, K, H - R + 1],
torch_fn, torch_fn,
'nc(h+r),crk->nkh', 'nc(h+r),crk->nkh',
dict())] dict(), None, None, None)]
# 2D Dense convolution # 2D Dense convolution
NCHWKRS = [ NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3), #(8, 64, 128, 128, 768, 3, 3),
#(128, 3, 32, 32, 64, 3, 3), #(128, 3, 32, 32, 64, 3, 3),
(1, 1024, 32, 112, 112, 1024, 3, 3), #(1, 1024, 32, 112, 112, 1024, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3) #(8, 512, 32, 32, 1024, 3, 3)
] ]
for N, C, G, H, W, K, R, S in NCHWKRS: for N, C, G, H, W, K, R, S in NCHWKRS:
@@ -129,7 +129,7 @@ for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
[N, K, D - T + 1, H - R + 1, W - R + 1], [N, K, D - T + 1, H - R + 1, W - R + 1],
torch_fn, torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw', 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict())] dict(), None, None, None)]
# Shift convolution # Shift convolution
@@ -172,7 +172,8 @@ for N, C, H, W, K, R, S in NCHWKRS:
[N, K, H, W], [N, K, H, W],
shift_conv, shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw', 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w})] {'sh': shift_h, 'sw': shift_w},
None, None, None)]
# Benchmark # Benchmark
torch.set_num_threads(1) torch.set_num_threads(1)