[GENERAL] Improved caching mechanism:

* Now computing hash in libtriton
* Now only compiling a single pytorch hook per function signature
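
A rough sketch of the per-signature hook cache described above (illustrative only — `_hook_cache`, `_compile_hook`, and `get_hook` are not triton's actual API, and the real hash is computed in libtriton on the C++ side):

    import torch

    _hook_cache = {}  # maps a function signature to its compiled PyTorch hook

    def _compile_hook(signature):
        # Stand-in for the real compilation step; in triton the hash is
        # computed in libtriton and the hook wraps the generated kernel.
        def hook(*tensors):
            return tensors  # placeholder body
        return hook

    def get_hook(name, *tensors):
        # The signature depends only on argument dtypes and ranks, so every
        # call that shares it reuses one compiled hook instead of recompiling.
        signature = (name,) + tuple((t.dtype, t.dim()) for t in tensors)
        if signature not in _hook_cache:
            _hook_cache[signature] = _compile_hook(signature)
        return _hook_cache[signature]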
Author: Philippe Tillet
Date: 2020-02-20 20:09:33 -05:00
Committed by: Philippe Tillet
Parent: 30f77e9ec5
Commit: dfb844bf41
14 changed files with 538 additions and 435 deletions


@@ -17,14 +17,14 @@ MNK = [
 (2048, 2048, 2048),
 #(8192, 8192, 8192),
-# (64, 64, 64000),
-# (64, 64, 128000),
-# (256, 256, 64000),
-# (256, 256, 128000),
+(64, 64, 64000),
+(64, 64, 128000),
+(256, 256, 64000),
+(256, 256, 128000),
-# (1536, 16, 1536),
-# (1536, 32, 1536),
-# (1536, 64, 1536),
+(1536, 16, 1536),
+(1536, 32, 1536),
+(1536, 64, 1536),
 # (1536, 128, 1536),
 # (4096, 16, 4096),
 # (4096, 32, 4096),
@@ -33,9 +33,9 @@ MNK = [
 # (127008, 768, 576)
 ]
-#for M, N, K in MNK:
-# matmul = lambda a, b: torch.matmul(a, b)
-# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
+for M, N, K in MNK:
+    matmul = lambda a, b: torch.matmul(a, b)
+    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
 #for M, N, K in MNK:
 # matmul = lambda a, b: torch.matmul(a.t(), b)
 # configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
@@ -175,15 +175,15 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
 a = torch.rand(*a_shape).type(dtype).cuda()
 b = torch.rand(*b_shape).type(dtype).cuda()
 # triton output
-print(a.size(), b.size())
-tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
+tc = torch.empty(c_shape, device=a.device)
+triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
 # reference output
 if torch_fn:
     rc = torch_fn(a, b, **arrays)
 else:
     rc = torch.einsum(expr, a, b)
 # performance relative to equivalent matrix multiplication
-ctx = triton.ctx_registry[tc]
+ctx = triton.ops._einsum.registry[tc]
 B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
 cmp_eqbmm = False
 if cmp_eqbmm:
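
For reference, a minimal usage sketch of the updated einsum interface shown in the hunk above — the caller now allocates the output tensor and passes it in, and the autograd context is looked up in the per-op registry afterwards. Shapes and dtype below are illustrative, and this assumes a triton build matching this commit:

    import torch
    import triton

    dtype = torch.float16
    a = torch.rand(64, 256).type(dtype).cuda()
    b = torch.rand(256, 128).type(dtype).cuda()
    c = torch.empty((64, 128), device=a.device)   # output allocated by the caller
    triton.ops.einsum('mk,kn->mn', a, b, c, arrays=dict(), bench=True)
    ctx = triton.ops._einsum.registry[c]          # per-op registry replaces triton.ctx_registry
    print(ctx.matmul_M, ctx.matmul_N, ctx.matmul_K)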