[PYTHON][EXAMPLES][EINSUM] Added group-convolution test/benchmark

Author:    Philippe Tillet
Date:      2020-04-09 23:37:39 -04:00
Committer: Philippe Tillet
Parent:    5bb977173f
Commit:    f22ad0064c

@@ -33,9 +33,9 @@ MNK = [
        # (127008, 768, 576)
        ]
-for M, N, K in MNK:
-  matmul = lambda a, b: torch.matmul(a, b)
-  configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+#for M, N, K in MNK:
+#  matmul = lambda a, b: torch.matmul(a, b)
+#  configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 #  matmul = lambda a, b: torch.matmul(a.t(), b)
 #  configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
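
Reviewer note (not part of the diff): the configs disabled above cover the plain-GEMM case, since 'mk,kn->mn' is an ordinary matrix product. A minimal standalone sanity check of that equivalence, with toy sizes picked purely for illustration:

    import torch

    # toy sizes, for illustration only
    M, N, K = 64, 32, 48
    a = torch.rand(M, K)
    b = torch.rand(K, N)

    # 'mk,kn->mn' contracts over k: c[m, n] = sum_k a[m, k] * b[k, n]
    c_einsum = torch.einsum('mk,kn->mn', a, b)
    c_matmul = torch.matmul(a, b)
    assert torch.allclose(c_einsum, c_matmul, atol=1e-5)
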
@@ -95,18 +95,22 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
            #(8, 64, 128, 128, 768, 3, 3),
-           (128, 3, 32, 32, 64, 3, 3),
-           #(8, 256, 32, 32, 512, 3, 3),
+           #(128, 3, 32, 32, 64, 3, 3),
+           (1, 1024, 32, 112, 112, 1024, 3, 3),
            #(8, 512, 32, 32, 1024, 3, 3)
            ]
-for N, C, H, W, K, R, S in NCHWKRS:
-  torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
+for N, C, G, H, W, K, R, S in NCHWKRS:
+  torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)
+  transform_a = lambda a: a.view(N, G, C // G, H, W)
+  transform_b = lambda b: b.view(C // G, R, S, G, K // G)
+  transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1)
   configs += [([N, C, H, W],
-               [K, C, R, S],
-               [N, K, H - R + 1, W - R + 1],
+               [C // G, R, S, K],
+               [N, G, K // G, H - R + 1, W - S + 1],
                torch_fn,
-               'nc(h+r)(w+s),kcrs->nkhw',
-               dict())]
+               'ngc(h+r)(w+s),crsgk->ngkhw',
+               dict(), transform_a, transform_b, transform_c)]

 # 3D Dense Convolution
 NCDHWKTRS = [
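
Reviewer note (not part of the diff): the expression 'ngc(h+r)(w+s),crsgk->ngkhw' uses Triton's extended einsum syntax, where the shifted indices (h+r) and (w+s) encode the convolution's sliding window. Plain torch.einsum has no shifted indices, but the same group convolution can be reproduced by materializing the windows with Tensor.unfold. A sketch of the equivalence the new config exercises, using small toy shapes (the benchmark itself uses N, C, G, H, W, K, R, S = 1, 1024, 32, 112, 112, 1024, 3, 3):

    import torch
    import torch.nn.functional as F

    # toy shapes, for illustration only
    N, C, G, H, W, K, R, S = 2, 8, 4, 10, 10, 8, 3, 3
    a = torch.rand(N, C, H, W)
    b = torch.rand(C // G, R, S, K)

    # reference: grouped conv2d expects weights of shape (K, C // G, R, S)
    ref = F.conv2d(a, b.permute(3, 0, 1, 2), groups=G)

    # einsum formulation: split channels into (group, channel-within-group),
    # exactly as transform_a / transform_b do above
    ta = a.view(N, G, C // G, H, W)
    tb = b.view(C // G, R, S, G, K // G)
    # materialize the R x S sliding windows: (N, G, C//G, H-R+1, W-S+1, R, S)
    patches = ta.unfold(3, R, 1).unfold(4, S, 1)
    tc = torch.einsum('ngchwrs,crsgk->ngkhw', patches, tb)

    assert torch.allclose(tc.reshape(N, K, H - R + 1, W - S + 1), ref, atol=1e-4)
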
@@ -169,28 +173,33 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
-for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
+for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
+    transform_a, transform_b, transform_c in configs:
   dtype = torch.cuda.FloatTensor
   # initialize input tensors
   a = torch.rand(*a_shape).type(dtype).cuda()
   b = torch.rand(*b_shape).type(dtype).cuda()
-  # triton output
-  tc = torch.empty(c_shape, device=a.device)
-  triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
   # reference output
   if torch_fn:
     rc = torch_fn(a, b, **arrays)
   else:
     rc = torch.einsum(expr, a, b)
-  # performance relative to equivalent matrix multiplication
+  # triton output
+  ta = a if transform_a is None else transform_a(a)
+  tb = b if transform_b is None else transform_b(b)
+  tc = torch.empty(c_shape, device=a.device)
+  triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
   ctx = triton.ops._einsum.registry[tc]
+  tc = tc if transform_c is None else transform_c(tc)
+  # performance relative to equivalent matrix multiplication
   B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-  cmp_eqbmm = False
+  cmp_eqbmm = True
   if cmp_eqbmm:
     a = torch.rand(B, M, K).type(dtype).cuda()
     b = torch.rand(B, K, N).type(dtype).cuda()
-    tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
-    ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
+    c = torch.empty((B, M, N), device=a.device).cuda()
+    tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
+    ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
    cmp_str = f'({ratio:4.2f})'
  else:
    cmp_str = ''
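
Reviewer note (not part of the diff): with cmp_eqbmm enabled, each einsum is timed against a batched matmul of the same reduced shape (B, M, N, K), so ratio > 1 means the einsum kernel beats the equivalent GEMM. A hypothetical epilogue for the loop, reusing its variables (rc, tc, ctx, B, M, N, K, cmp_str); the report helper itself is an illustrative assumption, not Triton API:

    def report(expr, rc, tc, ctx, B, M, N, K, cmp_str):
        # correctness: triton output (after transform_c) vs. torch reference
        max_err = (tc - rc).abs().max().item()
        # throughput: the einsum reduces to a (B, M, N, K) batched matmul,
        # i.e. 2*B*M*N*K flops; ctx.forward_ms is the kernel time in ms
        tflops = 2 * B * M * N * K / (ctx.forward_ms * 1e-3) * 1e-12
        print(f'{expr}: {ctx.forward_ms:.3f} ms, {tflops:.2f} TFLOPS '
              f'{cmp_str} (max err: {max_err:.2e})')
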