[PYTHON][EXAMPLES][EINSUM] Added group-convolution test/benchmark
committed by Philippe Tillet
parent 5bb977173f
commit f22ad0064c
@@ -33,9 +33,9 @@ MNK = [
           # (127008, 768, 576)
           ]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b)
-    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+#for M, N, K in MNK:
+#    matmul = lambda a, b: torch.matmul(a, b)
+#    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 #    matmul = lambda a, b: torch.matmul(a.t(), b)
 #    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
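For context, the matmul configs commented out above are sanity-checkable in plain PyTorch. A minimal sketch (sizes are arbitrary stand-ins, not from the commit): 'mk,kn->mn' is an ordinary GEMM, while 'mk,mn->kn' contracts over m, i.e. a.t() @ b.

import torch

M, N, K = 64, 48, 32
a = torch.rand(M, K)
b = torch.rand(K, N)
# 'mk,kn->mn' is a plain matrix multiplication
assert torch.allclose(torch.einsum('mk,kn->mn', a, b), torch.matmul(a, b), atol=1e-5)

c = torch.rand(M, N)
# 'mk,mn->kn' sums over m, which is exactly a.t() @ c
assert torch.allclose(torch.einsum('mk,mn->kn', a, c), torch.matmul(a.t(), c), atol=1e-5)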
@@ -95,18 +95,22 @@ for N, C, H, K, R in NCHKR:
 # 2D Dense convolution
 NCHWKRS = [
     #(8, 64, 128, 128, 768, 3, 3),
-    (128, 3, 32, 32, 64, 3, 3),
     #(8, 256, 32, 32, 512, 3, 3),
+    #(128, 3, 32, 32, 64, 3, 3),
+    (1, 1024, 32, 112, 112, 1024, 3, 3),
     #(8, 512, 32, 32, 1024, 3, 3)
     ]
-for N, C, H, W, K, R, S in NCHWKRS:
-    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
+for N, C, G, H, W, K, R, S in NCHWKRS:
+    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)
+    transform_a = lambda a: a.view(N, G, C // G, H, W)
+    transform_b = lambda b: b.view(C // G, R, S, G, K // G)
+    transform_c = lambda c: c.view(N, K, H - R + 1, W - S + 1)
     configs += [([N, C, H, W],
-                 [K, C, R, S],
-                 [N, K, H - R + 1, W - R + 1],
+                 [C // G, R, S, K],
+                 [N, G, K // G, H - R + 1, W - S + 1],
                  torch_fn,
-                 'nc(h+r)(w+s),kcrs->nkhw',
-                 dict())]
+                 'ngc(h+r)(w+s),crsgk->ngkhw',
+                 dict(), transform_a, transform_b, transform_c)]

 # 3D Dense Convolution
 NCDHWKTRS = [
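The grouped layout introduced in this hunk can be checked against torch's own grouped convolution. A minimal sketch with small stand-in sizes: the unfold-based torch.einsum below is an assumption standing in for triton's extended 'ngc(h+r)(w+s),crsgk->ngkhw' notation (which torch.einsum cannot parse), but it reshapes the operands exactly as transform_a and transform_b do.

import torch

N, C, G, H, W, K, R, S = 2, 8, 4, 8, 8, 8, 3, 3   # small stand-in sizes
a = torch.rand(N, C, H, W)
b = torch.rand(C // G, R, S, K)                    # b_shape from the config above

# reference: conv2d wants weights as (out_channels, in_channels // groups, R, S)
ref = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), groups=G)

# einsum path: view inputs as transform_a / transform_b do, materialize the
# (h+r, w+s) patches with unfold, and contract c, r, s within each group g
P, Q = H - R + 1, W - S + 1
av = a.view(N, G, C // G, H, W)
patches = av.unfold(3, R, 1).unfold(4, S, 1)       # (N, G, C//G, P, Q, R, S)
bv = b.view(C // G, R, S, G, K // G)
out = torch.einsum('ngcpqrs,crsgk->ngkpq', patches, bv).reshape(N, K, P, Q)
assert torch.allclose(out, ref, atol=1e-4)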
@@ -169,28 +173,33 @@ for N, C, H, W, K, R, S in NCHWKRS:

 # Benchmark
 torch.set_num_threads(1)
-for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
+for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
+    transform_a, transform_b, transform_c in configs:
     dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
-    # triton output
-    tc = torch.empty(c_shape, device=a.device)
-    triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
     # reference output
     if torch_fn:
         rc = torch_fn(a, b, **arrays)
     else:
         rc = torch.einsum(expr, a, b)
-    # performance relative to equivalent matrix multiplication
+    # triton output
+    ta = a if transform_a is None else transform_a(a)
+    tb = b if transform_b is None else transform_b(b)
+    tc = torch.empty(c_shape, device=a.device)
+    triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
+    ctx = triton.ops._einsum.registry[tc]
+    tc = tc if transform_c is None else transform_c(tc)
+    # performance relative to equivalent matrix multiplication
     B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = False
+    cmp_eqbmm = True
     if cmp_eqbmm:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
-        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
-        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
+        c = torch.empty((B, M, N), device=a.device).cuda()
+        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
+        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
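For readers unfamiliar with the equivalent-batched-matmul comparison: the einsum context exposes the (B, M, N, K) sizes the expression reduces to, a 'bmk,bkn->bmn' einsum is run at those sizes, and cmp_str prints the ratio of its forward time to the benchmarked kernel's. A rough sketch of the same measurement against a plain-torch baseline (the timing helper and sizes are hypothetical; triton records forward_ms internally when bench=True):

import time
import torch

def forward_ms(fn, reps=20):
    # crude wall-clock timing; assumes a CUDA device is present
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(reps):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / reps * 1e3

# hypothetical stand-ins for ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
B, M, N, K = 16, 256, 256, 256
a = torch.rand(B, M, K, device='cuda')
b = torch.rand(B, K, N, device='cuda')
bmm_ms = forward_ms(lambda: torch.bmm(a, b))
# in the diff, ratio = (bmm einsum forward_ms) / (benchmarked einsum forward_ms);
# a ratio above 1 means the benchmarked kernel beats its matmul equivalent
print(f'({bmm_ms:4.2f} ms)')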