diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py index 65b7f86a0..e00e6b5b1 100644 --- a/python/tests/test_matmul.py +++ b/python/tests/test_matmul.py @@ -83,14 +83,16 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50): def perf_op(dtype=th.float16, warmup=10, rep=50): AT, BT = False, False - configs = [(N, N, N) for N in [128, 8192]] - for M, N, K in configs: + import pandas as pd + df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH']) + Ns = [128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192] + configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns] + for AT, BT, M, N, K in configs: a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 if AT: a = a.t() if BT: b = b.t() - a = a[::,::] - b = b[::,::] TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep) TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep) - print((M, N, K), TH_MS, TT_MS) \ No newline at end of file + df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': TT_TFLOPS, 'TORCH': TH_TFLOPS}, ignore_index=True) + print(df) \ No newline at end of file