[PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of TK
Philippe Tillet authored on 2020-02-14 12:41:47 -05:00, committed by Philippe Tillet
parent fa4ec7ea65
commit 3816f2f259
5 changed files with 99 additions and 73 deletions
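Context for the title: TK here is the tile size along the reduction dimension of the batched GEMM that each einsum is lowered to, so a reduction size that is not a multiple of TK leaves a partial last tile that must be masked rather than read out of bounds. This is plausibly why the commit also enables the conv2d shape (128, 3, 32, 32, 64, 3, 3) below, whose reduction size C*R*S = 27 is not a multiple of any typical tile size. A minimal sketch of the idea in plain NumPy (tiled_dot and the TK default are illustrative; the actual Triton codegen is not shown in this diff):

# Minimal sketch, assuming TK denotes the reduction tile size: clamp the
# last tile to the array bound instead of assuming K % TK == 0.
# Plain NumPy for illustration only, not the Triton kernel itself.
import numpy as np

def tiled_dot(a, b, TK=32):
    # a: (M, K), b: (K, N); K need not be a multiple of TK
    M, K = a.shape
    _, N = b.shape
    acc = np.zeros((M, N), dtype=np.float32)
    for k0 in range(0, K, TK):
        k1 = min(k0 + TK, K)   # partial last tile when K % TK != 0
        acc += a[:, k0:k1] @ b[k0:k1, :]
    return acc

a = np.random.rand(64, 27).astype(np.float32)   # 27 is not a multiple of 32
b = np.random.rand(27, 48).astype(np.float32)
assert np.allclose(tiled_dot(a, b), a @ b, atol=1e-4)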

@@ -5,6 +5,8 @@ import numpy as np
 #import utils
 from time import time
+torch.manual_seed(0)
+#torch.backends.cudnn.benchmark = True
 configs = []
@@ -31,15 +33,15 @@ MNK = [
 # (127008, 768, 576)
 ]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b)
-    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a.t(), b)
-    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b.t())
-    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b)
+#    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a.t(), b)
+#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b.t())
+#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
 # Relative attention
 NTHSE = [
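The three expressions commented out above are the layouts of a linear layer: 'mk,kn->mn' is the forward GEMM, 'mk,mn->kn' reduces over m and corresponds to the weight gradient (A^T B), and 'mn,kn->mk' reduces over n and corresponds to the data gradient (A B^T), which is why the torch reference transposes one operand in each case. A hedged sketch of checking one layout by hand, reusing the triton.ops.einsum call that appears later in this file (bench=False and the float16 choice are assumptions, not taken from this diff):

# Illustrative sanity check of the 'mk,mn->kn' layout against torch.
import torch
import triton

M, N, K = 512, 512, 512
a = torch.rand(M, K).half().cuda()
b = torch.rand(M, N).half().cuda()
tc = triton.ops.einsum('mk,mn->kn', a, b, [K, N], arrays=dict(), bench=False)
rc = torch.einsum('mk,mn->kn', a, b)     # == torch.matmul(a.t(), b)
print((tc.float() - rc.float()).abs().max())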
@@ -70,16 +72,16 @@ NTHSE = [
 # (128, 1024, 8, 256, 256),
 #(128, 1024, 8, 256, 512)
 ]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
 # 1D Dense convolution
 NCHKR = [
-    (1, 1152, 12602, 512, 3)
+    #(1, 1152, 12602, 512, 3)
 ]
 for N, C, H, K, R in NCHKR:
     torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
     #(8, 64, 128, 128, 768, 3, 3),
-    (8, 128, 64, 64, 256, 3, 3),
-    (8, 256, 32, 32, 512, 3, 3),
+    (128, 3, 32, 32, 64, 3, 3),
+    #(8, 256, 32, 32, 512, 3, 3),
     #(8, 512, 32, 32, 1024, 3, 3)
 ]
 for N, C, H, W, K, R, S in NCHWKRS:
-    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
+    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
     configs += [([N, C, H, W],
-                 [C, R, S, K],
+                 [K, C, R, S],
                  [N, K, H - R + 1, W - R + 1],
                  torch_fn,
-                 'nc(h+r)(w+s),crsk->nkhw',
+                 'nc(h+r)(w+s),kcrs->nkhw',
                  dict())]
 # 3D Dense Convolution
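Two coupled changes in the hunk above: the benchmark's weight layout moves from CRSK to KCRS, which is the native (out_channels, in_channels, kH, kW) layout of torch.nn.functional.conv2d, so the reference no longer needs b.permute(3, 0, 1, 2); the einsum expression is updated to 'kcrs' to match. The shifted indices (h+r) and (w+s) express an implicit, unpadded cross-correlation. A hedged sketch of what the expression computes, written with explicit shifted views and plain torch.einsum (illustrative only; Triton derives this from the index notation):

# c[n, k, h, w] = sum over c, r, s of a[n, c, h + r, w + s] * b[k, c, r, s]
import torch
import torch.nn.functional as F

N, C, H, W, K, R, S = 2, 3, 8, 8, 4, 3, 3
a = torch.rand(N, C, H, W)
b = torch.rand(K, C, R, S)
P, Q = H - R + 1, W - S + 1               # valid (unpadded) output size
c = torch.zeros(N, K, P, Q)
for r in range(R):
    for s in range(S):
        c += torch.einsum('nchw,kc->nkhw', a[:, :, r:r+P, s:s+Q], b[:, :, r, s])
print((c - F.conv2d(a, b)).abs().max())   # conv2d is cross-correlation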
@@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
     # triton output
+    print(a.size(), b.size())
     tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
@@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     # performance relative to equivalent matrix multiplication
     ctx = triton.ctx_registry[tc]
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = True
+    cmp_eqbmm = False
    if cmp_eqbmm:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
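For context on the flag flipped above: every einsum in this benchmark is lowered to a batched GEMM whose sizes are recorded in ctx.matmul_B/M/N/K, and cmp_eqbmm additionally times torch.matmul on dense tensors of exactly those sizes as an upper-bound reference; this commit turns that comparison off by default. A hedged sketch of what such a reference measures (the CUDA-event timing here is an assumption; the benchmark's own timing helper is not visible in this diff):

# Time the 'equivalent batched matmul' once and report TFLOPS.
# A real benchmark would warm up and average over many launches.
import torch

B, M, N, K = 1, 512, 512, 512             # illustrative sizes
a = torch.rand(B, M, K).half().cuda()
b = torch.rand(B, K, N).half().cuda()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
c = torch.matmul(a, b)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end)
print(f'{2 * B * M * N * K / ms * 1e-9:.2f} TFLOPS')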