From 5bb977173f90211ae43f293acda985b8b58d78d6 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 9 Apr 2020 11:01:57 -0400
Subject: [PATCH] [PYTHON][EINSUM] re-established auto-tuning

---
 python/examples/einsum.py   | 68 ++++++++++++++-----------------------
 python/triton/ops/einsum.py |  8 ++---
 2 files changed, 28 insertions(+), 48 deletions(-)

diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index 4c8b4ac56..ce6d49210 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -13,18 +13,18 @@ configs = []
 # Matrix multiplication
 MNK = [
-         (1024, 1024, 1024),
+         (512, 512 ,512),
          (2048, 2048, 2048),
-         (8192, 8192, 8192),
+         #(8192, 8192, 8192),
-         #(64, 64, 64000),
-         #(64, 64, 128000),
-         #(256, 256, 64000),
-         #(256, 256, 128000),
+         (64, 64, 64000),
+         (64, 64, 128000),
+         (256, 256, 64000),
+         (256, 256, 128000),
-         #(1536, 16, 1536),
-         #(1536, 32, 1536),
-         #(1536, 64, 1536),
+         (1536, 16, 1536),
+         (1536, 32, 1536),
+         (1536, 64, 1536),
          # (1536, 128, 1536),
          # (4096, 16, 4096),
          # (4096, 32, 4096),
@@ -33,9 +33,9 @@ MNK = [
          # (127008, 768, 576)
          ]
-#for M, N, K in MNK:
-#   matmul = lambda a, b: torch.matmul(a, b)
-#   configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+for M, N, K in MNK:
+   matmul = lambda a, b: torch.matmul(a, b)
+   configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 #   matmul = lambda a, b: torch.matmul(a.t(), b)
 #   configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
@@ -94,8 +94,8 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
-           (8, 64, 128, 128, 768, 3, 3),
-           #(128, 3, 32, 32, 64, 3, 3),
+           #(8, 64, 128, 128, 768, 3, 3),
+           (128, 3, 32, 32, 64, 3, 3),
            #(8, 256, 32, 32, 512, 3, 3),
            #(8, 512, 32, 32, 1024, 3, 3)
            ]
@@ -160,39 +160,22 @@ for N, C, H, W, K, R, S in NCHWKRS:
         b = b.permute(1, 0)
         b = b.reshape(b.shape[0], b.shape[1], 1, 1)
         return torch.nn.functional.conv2d(a, b)
-    configs += [([N, C, H, W], [C, K], [N, K, H, W],
-                 shift_conv,
-                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
-                 {'sh': shift_h, 'sw': shift_w})]
-
-NCHWKX = [
-          #(8, 64, 128, 128, 128, 7)
-          ]
-for N, C, H, W, K, X in NCHWKX:
-    off_h = np.array([0, 0, 0, 1, 2, 3, 4], dtype=np.int32)
-    off_w = np.array([0, 1, 3, 1, 3, 0, 4], dtype=np.int32)
-    R, S = 5, 5
-    def sparse_conv(a, b, **kwargs):
-        off_h, off_w = kwargs['off_h'], kwargs['off_w']
-        K, C, X = b.shape
-        cvtb = torch.zeros([K, C, R, S], dtype=b.dtype, device=b.device)
-        cvtb[:, :, off_h, off_w] = b
-        return torch.nn.functional.conv2d(a, cvtb)
-    configs += [([N, C, H, W], [K, C, X], [N, K, H - R + 1, W - S + 1],
-                 sparse_conv,
-                 'nc(h + off_h[x])(w + off_w[x]),kcx->nkhw',
-                 {'off_h': off_h, 'off_w': off_w})]
-
+    configs += [([N, C, H, W],
+                 [C, K],
+                 [N, K, H, W],
+                 shift_conv,
+                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
+                 {'sh': shift_h, 'sw': shift_w})]
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-    dtype = torch.cuda.HalfTensor
+    dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
     # triton output
-    tc = torch.zeros(c_shape, dtype=a.dtype, device=a.device)
+    tc = torch.empty(c_shape, device=a.device)
     triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
@@ -202,13 +185,12 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
     # performance relative to equivalent matrix multiplication
     ctx = triton.ops._einsum.registry[tc]
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = True
+    cmp_eqbmm = False
     if cmp_eqbmm:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
-        tmmc = torch.empty([B, M, N]).type(dtype).cuda()
-        triton.ops.einsum('bmk,bkn->bmn', a, b, tmmc, bench = True)
-        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
+        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
+        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 9cee27d70..464d588ca 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -329,16 +329,14 @@ __global__ void {name}(
 #endif
 }
 """
-        # print(src)
         # compilation options
-        #TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
-        #TK = 16 if dtype==torch.float16 else 8
-        TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16
+        TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
+        TK = 16 if dtype==torch.float16 else 8
         defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
         if mask is not None:
            defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
         # create kernel
-        ret = triton.kernel(src, defines=defines, num_warps=[4])
+        ret = triton.kernel(src, defines=defines)
         # set constant
         if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
             ret.set_constant('AD', delta_a)
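
Note on the change: the functional core of this patch is in python/triton/ops/einsum.py,
where the tile-size macros TM, TN and TZ are handed to triton.kernel as lists of candidate
values instead of the fixed TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16, and the hard-coded
num_warps=[4] override is dropped, so kernel creation can once again auto-tune over the
candidate configurations. The snippet below is a minimal sketch of how the updated example
script exercises that path; it assumes the 2020-era Triton API used in this diff
(triton.ops.einsum with a pre-allocated output, bench=True, and the _einsum.registry context
object) and will not run as-is against later Triton releases. The 512x512x512 problem size
is simply one entry from the MNK list above.

    import torch
    import triton

    # One GEMM problem taken from the MNK list in python/examples/einsum.py.
    M, N, K = 512, 512, 512
    dtype = torch.cuda.FloatTensor
    a = torch.rand(M, K).type(dtype).cuda()
    b = torch.rand(K, N).type(dtype).cuda()

    # Pre-allocated output tensor, as in the example script. bench=True asks Triton to
    # time the kernel; because TM/TN/TZ are now lists of candidates, kernel creation can
    # pick the fastest configuration (the auto-tuning this patch re-establishes).
    tc = torch.empty([M, N], device=a.device)
    triton.ops.einsum('mk,kn->mn', a, b, tc, arrays=dict(), bench=True)

    # The measured runtime is recorded on the op's context object.
    ctx = triton.ops._einsum.registry[tc]
    print(f'einsum forward: {ctx.forward_ms:.3f} ms')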