[GENERAL] Merged einsum feature branch. Various feature, performance

improvements and bugfixes:

* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvement on FP32 kernels containing matrix
multiplication
* Added re-coalescing pass for FP16 kernels containing matrix
multiplication
* Various bugfixes
This commit is contained in:
Philippe Tillet
2020-01-16 12:09:50 -05:00
parent 50a52df489
commit f278d9741a
49 changed files with 1923 additions and 994 deletions

View File

@@ -1,92 +1,194 @@
#!/usr/bin/env python
import numpy as np
from enum import Enum
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
class MODE(Enum):
TF = 1
TORCH = 2
#torch.backends.cudnn.benchmark = True
try:
import tensorflow as tf
mode = MODE.TF
except ModuleNotFoundError:
pass
configs = []
try:
import torch
mode = MODE.TORCH
except ModuleNotFoundError:
pass
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
(8192, 8192, 8192),
(64, 64, 64000),
(64, 64, 128000),
(256, 256, 64000),
(256, 256, 128000),
cases = []
# Matmul
cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
# Attention
# cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
(1536, 16, 1536),
(1536, 32, 1536),
(1536, 64, 1536),
(1536, 128, 1536),
(4096, 16, 4096),
(4096, 32, 4096),
(4096, 64, 4096),
(4096, 128, 4096),
#(127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a.t(), b)
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b.t())
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
if mode == MODE.TF:
sess = tf.InteractiveSession()
# Relative attention
NTHSE = [
#(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
#(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
for a_shape, b_shape, c_shape, einsum in cases:
# (64, 1024, 1, 64, 64),
#(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
#(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
#(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# Execute (tensorflow)
if mode == MODE.TF:
a = tf.placeholder(tf.float32, a_shape, name="a")
b = tf.placeholder(tf.float32, b_shape, name="b")
e = tf.placeholder(tf.float32, c_shape, name="e")
c = triton.ops.einsum(einsum, a, b, 1)
da, db = tf.gradients(c, [a, b], e)
feed_dict = { a: A.astype(np.float32),
b: B.astype(np.float32),
e: E }
sess.run(tf.global_variables_initializer())
result = sess.run([c, da, db], feed_dict = feed_dict)
# Execute (torch)
if mode == MODE.TORCH:
a = torch.from_numpy(A).cuda()
b = torch.from_numpy(B).cuda()
e = torch.from_numpy(E).cuda()
a.requires_grad_(True)
b.requires_grad_(True)
c = triton.ops.einsum(einsum, a, b, 1)
torch.autograd.backward(c, e)
da = a.grad
db = b.grad
result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]
# benchmark
nanosec = triton.bench_registry[c]
ctx = triton.ctx_registry[c]
b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
ops = 2.*b*m*n*k
print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
#print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
#print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
# 1D Dense convolution
NCHKR = [
# (1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
configs += [([N, C, H],
[C, R, K],
[N, K, H - R + 1],
torch_fn,
'nc(h+r),crk->nkh',
dict())]
# test
ctx = triton.ctx_registry[c]
t_a = ctx.trans_a
t_b = ctx.trans_b
e_a = ctx.einsum_a
e_b = ctx.einsum_b
e_c = ctx.einsum_c
C = np.einsum(einsum, A, B)
if not t_a and not t_b: # NN
DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
elif not t_a and t_b: # NT
DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
elif t_a and not t_b: # TN
DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
c, da, db = result[0], result[1], result[2]
print('C diff:', np.abs((C - c)).max())
print('DA diff:', np.abs((DA - da)).max())
print('DB diff:', np.abs((DB - db)).max())
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
configs += [([N, C, H, W],
[C, R, S, K],
[N, K, H - R + 1, W - R + 1],
torch_fn,
'nc(h+r)(w+s),crsk->nkhw',
dict())]
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
configs += [([N, C, D, H, W],
[C, T, R, S, K],
[N, K, D - T + 1, H - R + 1, W - R + 1],
torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict())]
# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
'shift_cuda', ['kernels/shift_cuda.cpp',
'kernels/shift_cuda_kernel.cu'],
extra_cflags=['-O3'])
class shift(torch.autograd.Function):
@staticmethod
def forward(ctx, x, shift):
ctx.save_for_backward(shift)
return shift_cuda.forward(x, shift)
@staticmethod
def backward(ctx, grad_output):
shift, = ctx.saved_tensors
grad_output = shift_cuda.backward(grad_output, shift)
return grad_output, None
NCHWKRS = [
#(8, 64, 128, 128, 128, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
def shift_conv(a, b, **kwargs):
shift_h, shift_w = kwargs['sh'], kwargs['sw']
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
shift_torch = torch.from_numpy(shift_torch).cuda()
a = shift.apply(a, shift_torch)
b = b.permute(1, 0)
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
return torch.nn.functional.conv2d(a, b)
configs += [([N, C, H, W],
[C, K],
[N, K, H, W],
shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w})]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.HalfTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# triton output
#ta = triton.ops._einsum.pad(a, [4,4,4,4])
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
# reference output
if torch_fn:
rc = torch_fn(a, b, **arrays)
else:
rc = torch.einsum(expr, a, b)
# performance relative to equivalent matrix multiplication
ctx = triton.ctx_registry[tc]
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
# a = torch.rand(B, M, K).type(dtype).cuda()
# b = torch.rand(B, K, N).type(dtype).cuda()
# tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
# ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
ratio = 0
# test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')