[PYTHON][CORE] Deprecating Tensorflow support

This commit is contained in:
Philippe Tillet
2020-02-10 04:19:17 -05:00
committed by Philippe Tillet
parent d7a781dd40
commit 404dd18333
5 changed files with 26 additions and 108 deletions

View File

@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda()
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
cmp_str = f'({ratio:4.2f})'
else:
cmp_str = ''
# test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')