diff --git a/python/examples/einsum.py b/python/examples/einsum.py index e044088e7..1c6e078d1 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -33,15 +33,15 @@ MNK = [ # (127008, 768, 576) ] -#for M, N, K in MNK: -# matmul = lambda a, b: torch.matmul(a, b) -# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] +for M, N, K in MNK: + matmul = lambda a, b: torch.matmul(a, b) + configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)] #for M, N, K in MNK: # matmul = lambda a, b: torch.matmul(a.t(), b) -# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] +# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)] #for M, N, K in MNK: # matmul = lambda a, b: torch.matmul(a, b.t()) -# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())] +# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)] # Relative attention NTHSE = [ @@ -73,11 +73,11 @@ NTHSE = [ #(128, 1024, 8, 256, 512) ] #for N, T, H, S, E in NTHSE: -# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] +# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)] #for N, T, H, S, E in NTHSE: -# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())] +# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)] #for N, T, H, S, E in NTHSE: -# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())] +# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)] # 1D Dense convolution NCHKR = [ @@ -90,13 +90,13 @@ for N, C, H, K, R in NCHKR: [N, K, H - R + 1], torch_fn, 'nc(h+r),crk->nkh', - dict())] + dict(), None, None, None)] # 2D Dense convolution NCHWKRS = [ #(8, 64, 128, 128, 768, 3, 3), #(128, 3, 32, 32, 64, 3, 3), - (1, 1024, 32, 112, 112, 1024, 3, 3), + #(1, 1024, 32, 112, 112, 1024, 3, 3), #(8, 512, 32, 32, 1024, 3, 3) ] for N, C, G, H, W, K, R, S in NCHWKRS: @@ -129,7 +129,7 @@ for N, C, D, H, W, K, T, R, S in NCDHWKTRS: [N, K, D - T + 1, H - R + 1, W - R + 1], torch_fn, 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw', - dict())] + dict(), None, None, None)] # Shift convolution @@ -172,7 +172,8 @@ for N, C, H, W, K, R, S in NCHWKRS: [N, K, H, W], shift_conv, 'nc(h + sh[c])(w + sw[c]),ck->nkhw', - {'sh': shift_h, 'sw': shift_w})] + {'sh': shift_h, 'sw': shift_w}, + None, None, None)] # Benchmark torch.set_num_threads(1)