[PYTHON][EINSUM] re-established auto-tuning

Author:       Philippe Tillet
Date:         2020-04-09 11:01:57 -04:00
Committed by: Philippe Tillet
Parent:       ec2cb2155e
Commit:       5bb977173f
2 changed files with 28 additions and 48 deletions

@@ -13,18 +13,18 @@ configs = []
 # Matrix multiplication
 MNK = [
 (1024, 1024, 1024),
 (512, 512 ,512),
 (2048, 2048, 2048),
-(8192, 8192, 8192),
+#(8192, 8192, 8192),
-#(64, 64, 64000),
-#(64, 64, 128000),
-#(256, 256, 64000),
-#(256, 256, 128000),
+(64, 64, 64000),
+(64, 64, 128000),
+(256, 256, 64000),
+(256, 256, 128000),
-#(1536, 16, 1536),
-#(1536, 32, 1536),
-#(1536, 64, 1536),
+(1536, 16, 1536),
+(1536, 32, 1536),
+(1536, 64, 1536),
 # (1536, 128, 1536),
 # (4096, 16, 4096),
 # (4096, 32, 4096),
@@ -33,9 +33,9 @@ MNK = [
 # (127008, 768, 576)
 ]
-#for M, N, K in MNK:
-# matmul = lambda a, b: torch.matmul(a, b)
-# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+for M, N, K in MNK:
+    matmul = lambda a, b: torch.matmul(a, b)
+    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 # matmul = lambda a, b: torch.matmul(a.t(), b)
 # configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
@@ -94,8 +94,8 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
-(8, 64, 128, 128, 768, 3, 3),
-#(128, 3, 32, 32, 64, 3, 3),
+#(8, 64, 128, 128, 768, 3, 3),
+(128, 3, 32, 32, 64, 3, 3),
 #(8, 256, 32, 32, 512, 3, 3),
 #(8, 512, 32, 32, 1024, 3, 3)
 ]
@@ -160,39 +160,22 @@ for N, C, H, W, K, R, S in NCHWKRS:
         b = b.permute(1, 0)
         b = b.reshape(b.shape[0], b.shape[1], 1, 1)
         return torch.nn.functional.conv2d(a, b)
-    configs += [([N, C, H, W], [C, K], [N, K, H, W],
-                shift_conv,
-                'nc(h + sh[c])(w + sw[c]),ck->nkhw',
-                {'sh': shift_h, 'sw': shift_w})]
-NCHWKX = [
-#(8, 64, 128, 128, 128, 7)
-]
-for N, C, H, W, K, X in NCHWKX:
-    off_h = np.array([0, 0, 0, 1, 2, 3, 4], dtype=np.int32)
-    off_w = np.array([0, 1, 3, 1, 3, 0, 4], dtype=np.int32)
-    R, S = 5, 5
-    def sparse_conv(a, b, **kwargs):
-        off_h, off_w = kwargs['off_h'], kwargs['off_w']
-        K, C, X = b.shape
-        cvtb = torch.zeros([K, C, R, S], dtype=b.dtype, device=b.device)
-        cvtb[:, :, off_h, off_w] = b
-        return torch.nn.functional.conv2d(a, cvtb)
-    configs += [([N, C, H, W], [K, C, X], [N, K, H - R + 1, W - S + 1],
-                sparse_conv,
-                'nc(h + off_h[x])(w + off_w[x]),kcx->nkhw',
-                {'off_h': off_h, 'off_w': off_w})]
+    configs += [([N, C, H, W],
+                 [C, K],
+                 [N, K, H, W],
+                 shift_conv,
+                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
+                 {'sh': shift_h, 'sw': shift_w})]
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-    dtype = torch.cuda.HalfTensor
+    dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
     # triton output
-    tc = torch.zeros(c_shape, dtype=a.dtype, device=a.device)
+    tc = torch.empty(c_shape, device=a.device)
     triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
@@ -202,13 +185,12 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     # performance relative to equivalent matrix multiplication
     ctx = triton.ops._einsum.registry[tc]
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = True
+    cmp_eqbmm = False
     if cmp_eqbmm:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
-        tmmc = torch.empty([B, M, N]).type(dtype).cuda()
-        triton.ops.einsum('bmk,bkn->bmn', a, b, tmmc, bench = True)
-        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
+        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
+        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
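For reference, a minimal usage sketch of the benchmarking pattern introduced above, assuming the 2020-era triton.ops.einsum signature and the triton.ctx_registry lookup exactly as they appear in this diff (shapes are illustrative only):

import torch
import triton

# illustrative batched-matmul shapes
B, M, N, K = 1, 1024, 1024, 1024
a = torch.rand(B, M, K).cuda()
b = torch.rand(B, K, N).cuda()
# the updated call takes the output shape and returns the output tensor
c = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
# timings are read back from the global context registry, keyed by the output
print(f'forward: {triton.ctx_registry[c].forward_ms:.3f} ms')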

@@ -329,16 +329,14 @@ __global__ void {name}(
 #endif
 }
 """
-    # print(src)
     # compilation options
-    #TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
-    #TK = 16 if dtype==torch.float16 else 8
-    TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16
+    TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
+    TK = 16 if dtype==torch.float16 else 8
     defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
     if mask is not None:
         defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
     # create kernel
-    ret = triton.kernel(src, defines=defines, num_warps=[4])
+    ret = triton.kernel(src, defines=defines)
     # set constant
     if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
         ret.set_constant('AD', delta_a)
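This second changed file is where auto-tuning is re-established: TM, TN and TZ become lists of candidate values, TK follows the dtype, and the fixed num_warps=[4] override is dropped, presumably letting triton.kernel choose among the compiled variants. The self-contained sketch below only enumerates the implied search grid (the actual compilation and timing happen inside triton.kernel); dtype is assumed here for illustration:

import itertools
import torch

dtype = torch.float16  # assumed for illustration
TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
TK = 16 if dtype == torch.float16 else 8

# a scalar define contributes one candidate; a list contributes one per entry
as_list = lambda v: v if isinstance(v, list) else [v]
grid = list(itertools.product(*(as_list(v) for v in (TM, TN, TB, TK, TZ))))
print(f'{len(grid)} candidate configurations')  # 4 * 4 * 1 * 1 * 3 = 48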