[PYTHON][CORE] Deprecating Tensorflow support
This commit is contained in:
committed by
Philippe Tillet
parent
d7a781dd40
commit
404dd18333
@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
||||
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
|
||||
ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
|
||||
cmp_str = f'({ratio:4.2f})'
|
||||
else:
|
||||
cmp_str = ''
|
||||
# test and benchmark
|
||||
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
|
||||
bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
|
||||
diff = (tc - rc).abs().max() / rc.abs().max()
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
|
||||
|
Reference in New Issue
Block a user