diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index ce6d49210..a99329c79 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -33,9 +33,9 @@ MNK = [
          # (127008, 768, 576)
        ]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b)
-    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b)
+#    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 #    matmul = lambda a, b: torch.matmul(a.t(), b)
 #    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
@@ -95,18 +95,22 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
            #(8, 64, 128, 128, 768, 3, 3),
-           (128, 3, 32, 32, 64, 3, 3),
-           #(8, 256, 32, 32, 512, 3, 3),
+           #(128, 3, 32, 32, 64, 3, 3),
+           (1, 1024, 32, 112, 112, 1024, 3, 3),
            #(8, 512, 32, 32, 1024, 3, 3)
           ]
-for N, C, H, W, K, R, S in NCHWKRS:
-    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
+for N, C, G, H, W, K, R, S in NCHWKRS:
+    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)
+    transform_a = lambda a: a.view(N, G, C // G, H, W)
+    transform_b = lambda b: b.view(C // G, R, S, G, K // G)
+    transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1)
     configs += [([N, C, H, W],
-                 [K, C, R, S],
-                 [N, K, H - R + 1, W - R + 1],
+                 [C // G, R, S, K],
+                 [N, G, K // G, H - R + 1, W - S + 1],
                  torch_fn,
-                 'nc(h+r)(w+s),kcrs->nkhw',
-                 dict())]
+                 'ngc(h+r)(w+s),crsgk->ngkhw',
+                 dict(), transform_a, transform_b, transform_c)]
+
 # 3D Dense Convolution
 NCDHWKTRS = [
@@ -169,28 +173,33 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
-for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
+for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
+    transform_a, transform_b, transform_c in configs:
     dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
-    # triton output
-    tc = torch.empty(c_shape, device=a.device)
-    triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
         rc = torch_fn(a, b, **arrays)
     else:
         rc = torch.einsum(expr, a, b)
-    # performance relative to equivalent matrix multiplication
+    # triton output
+    ta = a if transform_a is None else transform_a(a)
+    tb = b if transform_b is None else transform_b(b)
+    tc = torch.empty(c_shape, device=a.device)
+    triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
     ctx = triton.ops._einsum.registry[tc]
+    tc = tc if transform_c is None else transform_c(tc)
+    # performance relative to equivalent matrix multiplication
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = False
+    cmp_eqbmm = True
    if cmp_eqbmm:
-        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
-        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
+        c = torch.empty((B, M, N), device=a.device).cuda()
+        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
+        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
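
Note (not part of the patch): below is a minimal standalone sketch, using plain PyTorch and small hypothetical sizes, of why the new grouped-convolution layout is expected to be correct. It checks that the einsum 'ngc(h+r)(w+s),crsgk->ngkhw' over the reshaped operands (emulated here with Tensor.unfold and torch.einsum, since no Triton is involved) matches torch.nn.functional.conv2d with groups=G and the weight permuted to (K, C//G, R, S), exactly as the patch's torch_fn and transform_a/b/c do.

import torch

# hypothetical small sizes for a quick CPU check (the patch uses N=1, C=1024, G=32, ...)
N, C, G, H, W, K, R, S = 2, 8, 4, 10, 10, 8, 3, 3

a = torch.rand(N, C, H, W)          # activation, same layout as a_shape in the patch
b = torch.rand(C // G, R, S, K)     # weight, same layout as the new b_shape [C//G, R, S, K]

# reference: grouped conv2d with the weight permuted to (K, C//G, R, S)
ref = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)

# einsum equivalent of 'ngc(h+r)(w+s),crsgk->ngkhw':
# split channels into (G, C//G), then materialize RxS sliding windows with unfold
ta = a.view(N, G, C // G, H, W)
ta = ta.unfold(3, R, 1).unfold(4, S, 1)       # (N, G, C//G, H-R+1, W-S+1, R, S)
tb = b.view(C // G, R, S, G, K // G)          # same as transform_b in the patch
out = torch.einsum('ngchwrs,crsgk->ngkhw', ta, tb)
out = out.reshape(N, K, H - R + 1, W - S + 1) # same as transform_c in the patch

print(torch.allclose(out, ref, atol=1e-4, rtol=1e-4))  # expected: True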