[PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of TK
commit 3816f2f259
parent fa4ec7ea65
committed by Philippe Tillet
@@ -5,6 +5,8 @@ import numpy as np
 #import utils
 from time import time
 
 torch.manual_seed(0)
 
 #torch.backends.cudnn.benchmark = True
 
 configs = []
@@ -31,15 +33,15 @@ MNK = [
 
 # (127008, 768, 576)
 ]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b)
-    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a.t(), b)
-    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b.t())
-    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b)
+#    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a.t(), b)
+#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b.t())
+#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
 
 # Relative attention
 NTHSE = [
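Note: each config tuple is (a_shape, b_shape, c_shape, torch_fn, expr, arrays). A minimal torch-only sketch of what the first 'mk,kn->mn' entry encodes (sizes are illustrative, not from this file; K = 48 is chosen so the reduction size is not a multiple of a typical TK such as 16 or 32, which is the case this commit enables; in the benchmark itself triton.ops.einsum takes the place of torch.einsum):

    import torch

    M, N, K = 64, 64, 48                    # K not a multiple of a typical TK
    a = torch.rand(M, K, device='cuda')
    b = torch.rand(K, N, device='cuda')
    ref = torch.matmul(a, b)                # the torch_fn of this config
    out = torch.einsum('mk,kn->mn', a, b)   # same contraction in einsum notation
    assert torch.allclose(ref, out, atol=1e-4)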
@@ -70,16 +72,16 @@ NTHSE = [
 # (128, 1024, 8, 256, 256),
 #(128, 1024, 8, 256, 512)
 ]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
-for N, T, H, S, E in NTHSE:
-    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
+#for N, T, H, S, E in NTHSE:
+#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
 
 # 1D Dense convolution
 NCHKR = [
-    (1, 1152, 12602, 512, 3)
+    #(1, 1152, 12602, 512, 3)
 ]
 for N, C, H, K, R in NCHKR:
     torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
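Note: the three NTHSE expressions form a forward/backward trio: 'nths,hes->nhte' is the forward contraction over s, 'nhte,nths->hes' reduces over n and t (gradient w.r.t. the weight), and 'nhte,hes->nths' reduces over e (gradient w.r.t. the input). A shape-check sketch with made-up sizes (the names x, w, dy are illustrative only):

    import torch

    N, T, H, S, E = 2, 8, 4, 16, 32
    x  = torch.rand(N, T, H, S)                   # per-head activations (illustrative)
    w  = torch.rand(H, E, S)                      # relative-position weights (illustrative)
    y  = torch.einsum('nths,hes->nhte', x, w)     # forward: reduce over s
    dy = torch.rand(N, H, T, E)                   # stand-in upstream gradient
    dw = torch.einsum('nhte,nths->hes', dy, x)    # grad w.r.t. w: reduce over n, t
    dx = torch.einsum('nhte,hes->nths', dy, w)    # grad w.r.t. x: reduce over e
    assert y.shape == (N, H, T, E) and dw.shape == w.shape and dx.shape == x.shape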
@@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
     #(8, 64, 128, 128, 768, 3, 3),
     (8, 128, 64, 64, 256, 3, 3),
-    (8, 256, 32, 32, 512, 3, 3),
+    (128, 3, 32, 32, 64, 3, 3),
     #(8, 256, 32, 32, 512, 3, 3),
     #(8, 512, 32, 32, 1024, 3, 3)
 ]
 for N, C, H, W, K, R, S in NCHWKRS:
-    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
+    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
     configs += [([N, C, H, W],
-                 [C, R, S, K],
+                 [K, C, R, S],
                  [N, K, H - R + 1, W - R + 1],
                  torch_fn,
-                 'nc(h+r)(w+s),crsk->nkhw',
+                 'nc(h+r)(w+s),kcrs->nkhw',
                  dict())]
 
 # 3D Dense Convolution
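Note: the weight layout moves from CRSK to KCRS, which is conv2d's native (out_channels, in_channels, kH, kW) layout; permute(3, 0, 1, 2) of a CRSK tensor yields exactly KCRS, so with weights stored as KCRS the permute in the reference can be dropped. If the (128, 3, 32, 32, 64, 3, 3) entry is the newly enabled one, its reduction size C*R*S = 3*3*3 = 27 is precisely the non-multiple-of-TK case named in the commit title. A torch-only sketch checking that the shifted-window einsum matches conv2d (illustrative sizes):

    import torch
    import torch.nn.functional as F

    N, C, H, W, K, R, S = 2, 3, 8, 8, 4, 3, 3    # C*R*S = 27, not a multiple of TK
    a = torch.rand(N, C, H, W)
    b = torch.rand(K, C, R, S)                   # KCRS: conv2d's native layout
    # emulate 'nc(h+r)(w+s),kcrs->nkhw' with explicit sliding windows
    win = a.unfold(2, R, 1).unfold(3, S, 1)      # (N, C, H-R+1, W-S+1, R, S)
    out = torch.einsum('nchwrs,kcrs->nkhw', win, b)
    ref = F.conv2d(a, b)                         # stride 1, no padding
    assert torch.allclose(out, ref, atol=1e-4)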
@@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
     # triton output
+    print(a.size(), b.size())
     tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
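Note: for reference, one config run through the same entry point as a standalone sketch (the triton.ops.einsum call shape is taken verbatim from the hunk above; the sizes, dtype, and the comparison line are illustrative, not the file's actual code):

    import torch
    import triton

    a = torch.rand(64, 48).cuda()    # reduction size 48, not a multiple of a typical TK
    b = torch.rand(48, 64).cuda()
    tc = triton.ops.einsum('mk,kn->mn', a, b, [64, 64], arrays=dict(), bench=True)
    rc = torch.matmul(a, b)          # this config's torch_fn reference
    print((tc - rc).abs().max())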
@@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     # performance relative to equivalent matrix multiplication
     ctx = triton.ctx_registry[tc]
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = True
+    cmp_eqbmm = False
     if cmp_eqbmm:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
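Note: the cmp_eqbmm path, disabled here, compares each einsum against an equivalent batched matmul with the B, M, N, K extracted from the einsum context. A hedged sketch of such a baseline timing in plain torch (sizes and timing style are illustrative; only torch.bmm and the synchronize/time pattern are assumed):

    import torch
    from time import time

    B, M, N, K = 1, 1024, 1024, 1000          # K not a multiple of a typical TK
    a = torch.rand(B, M, K).cuda()
    b = torch.rand(B, K, N).cuda()
    torch.cuda.synchronize()
    t0 = time()
    c = torch.bmm(a, b)                        # the equivalent batched matmul
    torch.cuda.synchronize()
    print('bmm: %.3f ms' % (1e3 * (time() - t0)))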