Files
triton/python/examples/einsum.py

214 lines
7.3 KiB
Python
Raw Normal View History

import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
# Seed torch's RNG so the random operands are the same on every run.
torch.manual_seed(0)
#torch.backends.cudnn.benchmark = True

# Every entry of `configs` is a 9-tuple, unpacked by the benchmark loop as
#   (a_shape, b_shape, c_shape, torch_fn, expr, arrays,
#    transform_a, transform_b, transform_c)
# `torch_fn` is the PyTorch reference called as torch_fn(a, b, **arrays);
# when it is None the benchmark falls back to torch.einsum(expr, a, b).
# The transform_* entries, when not None, are applied to the operands /
# result around the triton call.
configs = []

# Matrix multiplication
# Entries are (M, N, K): a is [M, K], b is [K, N], c is [M, N].
MNK = [
    (512, 512, 512),
    (2048, 2048, 2048),
    #(8192, 8192, 8192),
    (64, 64, 64000),
    (64, 64, 128000),
    (256, 256, 64000),
    (256, 256, 128000),
    (1536, 16, 1536),
    (1536, 32, 1536),
    (1536, 64, 1536),
    # (1536, 128, 1536),
    # (4096, 16, 4096),
    # (4096, 32, 4096),
    # (4096, 64, 4096),
    # (4096, 128, 4096),
    # (127008, 768, 576)
]


def matmul(a, b):
    """Reference PyTorch implementation for the 'mk,kn->mn' configs."""
    return torch.matmul(a, b)


# The reference function does not depend on the shape, so it is defined once
# above instead of being rebuilt as a lambda on every iteration.
for M, N, K in MNK:
    configs.append(([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None))
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a.t(), b)
#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a, b.t())
#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
# Relative attention
# Each entry is (N, T, H, S, E), the order in which the (currently
# commented-out) loops below unpack it. Those loops are the only consumers,
# so as long as they stay disabled this list adds nothing to `configs`.
NTHSE = [
    (16, 512, 1, 64, 64),
    # (16, 512, 1, 128, 128),
    # (16, 512, 1, 256, 256),
    # (16, 512, 1, 256, 512),
    (16, 512, 8, 64, 64),
    # (16, 512, 8, 128, 128),
    # (16, 512, 8, 256, 256),
    # (16, 512, 8, 256, 512),
    # (64, 1024, 1, 64, 64),
    (64, 1024, 1, 128, 128),
    # (64, 1024, 1, 256, 256),
    # (64, 1024, 1, 256, 512),
    # (64, 1024, 8, 64, 64),
    (64, 1024, 8, 128, 128),
    # (64, 1024, 8, 256, 256),
    # (64, 1024, 8, 256, 512),
    # (128, 1024, 1, 64, 64),
    # (128, 1024, 1, 128, 128),
    # (128, 1024, 1, 256, 256),
    (128, 1024, 1, 256, 512),
    # (128, 1024, 8, 64, 64),
    # (128, 1024, 8, 128, 128),
    # (128, 1024, 8, 256, 256),
    #(128, 1024, 8, 256, 512)
]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
# 1D Dense convolution
# Entries are (N, C, H, K, R): a is [N, C, H], b is [C, R, K] and the
# output is [N, K, H - R + 1].
NCHKR = [
    #(1, 1152, 12602, 512, 3)
]


def make_conv1d_config(N, C, H, K, R):
    """Build the benchmark config tuple for one 1D convolution shape.

    Returns the 9-tuple layout used by `configs`:
    (a_shape, b_shape, c_shape, torch_fn, expr, arrays,
     transform_a, transform_b, transform_c).
    """

    def torch_fn(a, b):
        # b is stored as [C, R, K]; conv1d takes its weight as [K, C, R].
        return torch.nn.functional.conv1d(a, b.permute(2, 0, 1))

    return ([N, C, H],
            [C, R, K],
            [N, K, H - R + 1],
            torch_fn,
            'nc(h+r),crk->nkh',
            dict(), None, None, None)


for N, C, H, K, R in NCHKR:
    configs.append(make_conv1d_config(N, C, H, K, R))
# 2D Dense convolution
# Entries are (N, C, G, H, W, K, R, S): a is [N, C, H, W], b is
# [C // G, R, S, K] and G is passed to conv2d as `groups`.
# NOTE: three of the four sample rows below have only 7 fields; a G value
# must be inserted after C before they can be enabled.
NCHWKRS = [
    #(8, 64, 128, 128, 768, 3, 3),
    #(128, 3, 32, 32, 64, 3, 3),
    #(1, 1024, 32, 112, 112, 1024, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]


def make_conv2d_config(N, C, G, H, W, K, R, S, stride=2):
    """Build the benchmark config tuple for one strided, grouped 2D convolution.

    The reference function and the three view transforms are closures over
    this call's arguments. They are invoked long after the shape loop has
    finished, so they must not read the loop variables from module scope:
    doing so would make every config use the last shape in the list.

    Returns the 9-tuple layout used by `configs`:
    (a_shape, b_shape, c_shape, torch_fn, expr, arrays,
     transform_a, transform_b, transform_c).
    """
    # Unpadded strided convolution output size: floor((H - R) / stride) + 1.
    # The earlier (H - R + 1) // stride is one too small whenever H - R is a
    # multiple of the stride.
    P = (H - R) // stride + 1
    Q = (W - S) // stride + 1

    def torch_fn(a, b):
        # b is stored as [C // G, R, S, K]; conv2d takes [K, C // G, R, S].
        return torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)

    def transform_a(a):
        # Split the channel axis into (group, channel-within-group).
        return a.view(N, G, C // G, H, W)

    def transform_b(b):
        # Split the output-channel axis into (group, channel-within-group).
        return b.view(C // G, R, S, G, K // G)

    def transform_c(c):
        # Merge (group, channel-within-group) back into K output channels.
        return c.view(N, K, P, Q)

    # The stride appears in the index expression; for the default stride of 2
    # this is exactly 'ngc(h*2+r)(w*2+s),crsgk->ngkhw'.
    expr = f'ngc(h*{stride}+r)(w*{stride}+s),crsgk->ngkhw'
    return ([N, C, H, W],
            [C // G, R, S, K],
            [N, G, K // G, P, Q],
            torch_fn,
            expr,
            dict(), transform_a, transform_b, transform_c)


for N, C, G, H, W, K, R, S in NCHWKRS:
    configs.append(make_conv2d_config(N, C, G, H, W, K, R, S))
# 3D Dense Convolution
# Entries are (N, C, D, H, W, K, T, R, S): a is [N, C, D, H, W], b is
# [C, T, R, S, K] (filter extents T, R, S along D, H, W).
NCDHWKTRS = [
    #(8, 32, 27, 100, 100, 64, 3, 3, 3),
    #(8, 64, 23, 48, 48, 256, 3, 3, 3),
    #(8, 256, 19, 22, 22, 640, 3, 3, 3),
    #(8, 640, 15, 36, 36, 384, 3, 3, 3)
]


def make_conv3d_config(N, C, D, H, W, K, T, R, S):
    """Build the benchmark config tuple for one 3D convolution shape.

    Returns the 9-tuple layout used by `configs`:
    (a_shape, b_shape, c_shape, torch_fn, expr, arrays,
     transform_a, transform_b, transform_c).
    """

    def torch_fn(a, b):
        # b is stored as [C, T, R, S, K]; conv3d takes [K, C, T, R, S].
        return torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))

    # The filter extent along W is S, so the output width is W - S + 1
    # (it was W - R + 1, which is only right when R == S).
    return ([N, C, D, H, W],
            [C, T, R, S, K],
            [N, K, D - T + 1, H - R + 1, W - S + 1],
            torch_fn,
            'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
            dict(), None, None, None)


for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
    configs.append(make_conv3d_config(N, C, D, H, W, K, T, R, S))
# Shift convolution
# Build/load the `shift_cuda` extension from its C++ and CUDA sources.
# The source paths are relative (presumably to the working directory the
# script is launched from -- confirm before moving this file).
shift_cuda = load(
    name='shift_cuda',
    sources=[
        'kernels/shift_cuda.cpp',
        'kernels/shift_cuda_kernel.cu',
    ],
    extra_cflags=['-O3'],
)
class shift(torch.autograd.Function):
    """Autograd function that delegates to the `shift_cuda` extension."""

    @staticmethod
    def forward(ctx, x, shift):
        """Return `shift_cuda.forward(x, shift)`, saving `shift` for backward."""
        # The shift tensor is the only state backward() needs.
        ctx.save_for_backward(shift)
        shifted = shift_cuda.forward(x, shift)
        return shifted

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for `x`; the `shift` input gets no gradient."""
        (saved_shift,) = ctx.saved_tensors
        grad_input = shift_cuda.backward(grad_output, saved_shift)
        # One entry per forward() input: (x, shift).
        return grad_input, None
# Shapes for the shift-convolution benchmark, unpacked below as
# (N, C, H, W, K, R, S). This rebinds the NCHWKRS name used earlier in the
# file for the 2D dense convolution shapes.
NCHWKRS = [
    #(8, 64, 128, 128, 128, 3, 3),
    #(8, 128, 64, 64, 256, 3, 3),
    #(8, 256, 32, 32, 512, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    # One random integer offset per channel, drawn from [0, R) / [0, S) and
    # then shifted down by R//2 / S//2. 'sh' is drawn before 'sw'.
    offsets = {
        'sh': np.random.randint(R, size=C, dtype=np.int32) - R // 2,
        'sw': np.random.randint(S, size=C, dtype=np.int32) - S // 2,
    }

    def shift_conv(a, b, **kwargs):
        """PyTorch reference: shift `a` per channel, then apply a 1x1 conv2d."""
        sh = kwargs['sh']
        sw = kwargs['sw']
        # The extension is handed the negated offsets as (w, h) column pairs.
        negated = np.column_stack((-sw, -sh))
        shifted = shift.apply(a, torch.from_numpy(negated).cuda())
        # b is [C, K]; conv2d takes its weight as [K, C, 1, 1].
        weight = b.permute(1, 0)
        weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
        return torch.nn.functional.conv2d(shifted, weight)

    configs.append((
        [N, C, H, W],
        [C, K],
        [N, K, H, W],
        shift_conv,
        'nc(h + sh[c])(w + sw[c]),ck->nkhw',
        offsets,
        None, None, None,
    ))
# Benchmark
# For every config: compute the PyTorch reference, run the triton einsum in
# bench mode, compare it with an equivalent batched matmul, and print one line.
torch.set_num_threads(1)
for (a_shape, b_shape, c_shape, torch_fn, expr, arrays,
     transform_a, transform_b, transform_c) in configs:
    dtype = torch.cuda.FloatTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # reference output: the config's own function, or torch.einsum if none
    rc = torch_fn(a, b, **arrays) if torch_fn else torch.einsum(expr, a, b)
    # triton output, written into the preallocated `tc`
    ta = transform_a(a) if transform_a is not None else a
    tb = transform_b(b) if transform_b is not None else b
    tc = torch.empty(c_shape, device=a.device)
    triton.ops.einsum(expr, ta, tb, tc, arrays=arrays, bench=True)
    # The registry is keyed by the output tensor, so look it up before `tc`
    # is possibly rebound to a transformed view.
    ctx = triton.ops._einsum.registry[tc]
    if transform_c is not None:
        tc = transform_c(tc)
    # performance relative to equivalent matrix multiplication
    B = ctx.matmul_B
    M = ctx.matmul_M
    N = ctx.matmul_N
    K = ctx.matmul_K
    cmp_eqbmm = True
    cmp_str = ''
    if cmp_eqbmm:
        bmm_a = torch.rand(B, M, K).type(dtype).cuda()
        bmm_b = torch.rand(B, K, N).type(dtype).cuda()
        bmm_c = torch.empty((B, M, N), device=bmm_a.device).cuda()
        tmmc = triton.ops.einsum('bmk,bkn->bmn', bmm_a, bmm_b, bmm_c, bench=True)
        bmm_ms = triton.ops._einsum.registry[tmmc].forward_ms
        ratio = bmm_ms / ctx.forward_ms
        cmp_str = f'({ratio:4.2f})'
    # test and benchmark
    flops = 2. * B * M * N * K
    bench = flops / ctx.forward_ms * 1e-3
    # Largest absolute error, relative to the largest reference magnitude.
    max_err = (tc - rc).abs().max()
    diff = max_err / rc.abs().max()
    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')