[PYTHON][EINSUM] re-established auto-tuning
committed by Philippe Tillet
parent ec2cb2155e
commit 5bb977173f
@@ -13,18 +13,18 @@ configs = []
# Matrix multiplication
MNK = [
    (1024, 1024, 1024),
    (512, 512 ,512),
    (2048, 2048, 2048),
    (8192, 8192, 8192),
    #(8192, 8192, 8192),

    #(64, 64, 64000),
    #(64, 64, 128000),
    #(256, 256, 64000),
    #(256, 256, 128000),
    (64, 64, 64000),
    (64, 64, 128000),
    (256, 256, 64000),
    (256, 256, 128000),

    #(1536, 16, 1536),
    #(1536, 32, 1536),
    #(1536, 64, 1536),
    (1536, 16, 1536),
    (1536, 32, 1536),
    (1536, 64, 1536),
    # (1536, 128, 1536),
    # (4096, 16, 4096),
    # (4096, 32, 4096),
@@ -33,9 +33,9 @@ MNK = [

    # (127008, 768, 576)
]
#for M, N, K in MNK:
#   matmul = lambda a, b: torch.matmul(a, b)
#   configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
#for M, N, K in MNK:
#   matmul = lambda a, b: torch.matmul(a.t(), b)
#   configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
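Each configs entry pairs the operand shapes with a PyTorch baseline and an einsum expression. For these matmul cases the expression 'mk,kn->mn' is an ordinary matrix product, so the baseline could equivalently be written with torch.einsum. A small stand-alone check of that equivalence (illustrative only, not part of the commit):

import torch

M, N, K = 1024, 1024, 1024
a = torch.rand(M, K, device='cuda')
b = torch.rand(K, N, device='cuda')
# 'mk,kn->mn' contracts the shared k index: out[m, n] = sum_k a[m, k] * b[k, n]
out_einsum = torch.einsum('mk,kn->mn', a, b)
out_matmul = torch.matmul(a, b)
assert torch.allclose(out_einsum, out_matmul, atol=1e-4)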
@@ -94,8 +94,8 @@ for N, C, H, K, R in NCHKR:

# 2D Dense convolution
NCHWKRS = [
    (8, 64, 128, 128, 768, 3, 3),
    #(128, 3, 32, 32, 64, 3, 3),
    #(8, 64, 128, 128, 768, 3, 3),
    (128, 3, 32, 32, 64, 3, 3),
    #(8, 256, 32, 32, 512, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]
@@ -160,39 +160,22 @@ for N, C, H, W, K, R, S in NCHWKRS:
        b = b.permute(1, 0)
        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
        return torch.nn.functional.conv2d(a, b)
    configs += [([N, C, H, W], [C, K], [N, K, H, W],
                 shift_conv,
                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                 {'sh': shift_h, 'sw': shift_w})]

NCHWKX = [
    #(8, 64, 128, 128, 128, 7)
]
for N, C, H, W, K, X in NCHWKX:
    off_h = np.array([0, 0, 0, 1, 2, 3, 4], dtype=np.int32)
    off_w = np.array([0, 1, 3, 1, 3, 0, 4], dtype=np.int32)
    R, S = 5, 5
    def sparse_conv(a, b, **kwargs):
        off_h, off_w = kwargs['off_h'], kwargs['off_w']
        K, C, X = b.shape
        cvtb = torch.zeros([K, C, R, S], dtype=b.dtype, device=b.device)
        cvtb[:, :, off_h, off_w] = b
        return torch.nn.functional.conv2d(a, cvtb)
    configs += [([N, C, H, W], [K, C, X], [N, K, H - R + 1, W - S + 1],
                 sparse_conv,
                 'nc(h + off_h[x])(w + off_w[x]),kcx->nkhw',
                 {'off_h': off_h, 'off_w': off_w})]

    configs += [([N, C, H, W],
                 [C, K],
                 [N, K, H, W],
                 shift_conv,
                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                 {'sh': shift_h, 'sw': shift_w})]

# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    dtype = torch.cuda.HalfTensor
    dtype = torch.cuda.FloatTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # triton output
    tc = torch.zeros(c_shape, dtype=a.dtype, device=a.device)
    tc = torch.empty(c_shape, device=a.device)
    triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
    # reference output
    if torch_fn:
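The shift-conv entries use the extended expression 'nc(h + sh[c])(w + sw[c]),ck->nkhw', i.e. out[n, k, h, w] = sum_c a[n, c, h + sh[c], w + sw[c]] * b[c, k]: each input channel is read at a per-channel spatial offset and then mixed by the pointwise [C, K] matrix. A rough PyTorch reference of that contraction, assuming circular boundary handling via torch.roll (a sketch for illustration, not the commit's code):

import torch

def shift_conv_reference(a, b, sh, sw):
    # a: [N, C, H, W] activations, b: [C, K] pointwise weights
    # sh, sw: per-channel spatial shifts of length C
    N, C, H, W = a.shape
    # roll by the negative shift so index h of the result reads a[..., h + sh[c], ...]
    shifted = torch.stack(
        [torch.roll(a[:, c], shifts=(-int(sh[c]), -int(sw[c])), dims=(1, 2))
         for c in range(C)], dim=1)
    # contract the channel dimension against the [C, K] matrix
    return torch.einsum('nchw,ck->nkhw', shifted, b)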
@@ -202,13 +185,12 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    # performance relative to equivalent matrix multiplication
    ctx = triton.ops._einsum.registry[tc]
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    cmp_eqbmm = True
    cmp_eqbmm = False
    if cmp_eqbmm:
        a = torch.rand(B, M, K).type(dtype).cuda()
        b = torch.rand(B, K, N).type(dtype).cuda()
        tmmc = torch.empty([B, M, N]).type(dtype).cuda()
        triton.ops.einsum('bmk,bkn->bmn', a, b, tmmc, bench = True)
        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
        cmp_str = f'({ratio:4.2f})'
    else:
        cmp_str = ''
@@ -329,16 +329,14 @@ __global__ void {name}(
#endif
}
"""
    # print(src)
    # compilation options
    #TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
    #TK = 16 if dtype==torch.float16 else 8
    TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16
    TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
    TK = 16 if dtype==torch.float16 else 8
    defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
    if mask is not None:
        defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
    # create kernel
    ret = triton.kernel(src, defines=defines, num_warps=[4])
    ret = triton.kernel(src, defines=defines)
    # set constant
    if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
        ret.set_constant('AD', delta_a)
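This last hunk is where auto-tuning is actually re-established: TM, TN and TZ go back to being lists of candidate tile sizes, TK is derived from the data type, and the explicit num_warps=[4] argument is dropped from the triton.kernel call. Conceptually, list-valued compile-time defines span a search space, and a tuner can benchmark every combination and keep the fastest one. A minimal stand-alone sketch of that idea (the compile_and_time callback is hypothetical, not Triton's internal API):

import itertools

def autotune(defines, compile_and_time):
    # defines: mapping name -> concrete value or list of candidate values
    # compile_and_time: hypothetical callback that builds the kernel for one
    # fully specified set of defines and returns its measured runtime in ms
    names = list(defines)
    pools = [v if isinstance(v, list) else [v] for v in defines.values()]
    best_cfg, best_ms = None, float('inf')
    for combo in itertools.product(*pools):
        cfg = dict(zip(names, combo))
        ms = compile_and_time(cfg)
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg, best_ms

# e.g. autotune({'TM': [16, 32, 64, 128], 'TN': [16, 32, 64, 128],
#                'TB': 1, 'TZ': [1, 4, 16], 'TK': 8}, compile_and_time)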