[DOCS] Improved plots in tutorials
This commit is contained in:
@@ -14,7 +14,6 @@ square_confs = [
|
||||
y_vals = [16, 32, 64],
|
||||
y_lines = ['Block16', 'Block32', 'Block64'],
|
||||
ylabel = 'TFLOPS',
|
||||
loglog = False,
|
||||
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
|
||||
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
|
||||
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
@@ -65,7 +64,6 @@ square_confs = [
|
||||
y_vals = [16, 32, 64],
|
||||
y_lines = ['Block16', 'Block32', 'Block64'],
|
||||
ylabel = 'GBPS',
|
||||
loglog = False,
|
||||
plot_name = f'{layout_mode}-square',
|
||||
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)\
|
||||
|
@@ -9,7 +9,6 @@ confs = [
|
||||
y_vals = ['triton', 'torch'],
|
||||
y_lines = ['Triton', 'Torch'],
|
||||
ylabel = 'GBPS',
|
||||
loglog = False,
|
||||
plot_name = f'{mode}-2048',
|
||||
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
|
||||
)\
|
||||
|
@@ -20,7 +20,6 @@ square_confs = [
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
loglog=False,
|
||||
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
|
||||
args={
|
||||
"AT": AT,
|
||||
@@ -39,17 +38,16 @@ transformer_confs = [
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
loglog=False,
|
||||
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
|
||||
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
|
||||
) for NK in [8192]\
|
||||
) for NK in [12288]\
|
||||
for i, x in enumerate(["N", "K"])\
|
||||
for M in [2048]
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
|
||||
@triton.testing.perf_report(transformer_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
if AT: a = a.t()
|
||||
|
@@ -4,7 +4,8 @@ import os
|
||||
import inspect
|
||||
import triton
|
||||
|
||||
def run_all(result_dir, with_plots, names):
|
||||
|
||||
def run_all(result_dir, names):
|
||||
if not os.path.exists(result_dir):
|
||||
os.makedirs(result_dir)
|
||||
for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
|
||||
@@ -26,16 +27,17 @@ def run_all(result_dir, with_plots, names):
|
||||
curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
|
||||
if not os.path.exists(curr_dir):
|
||||
os.makedirs(curr_dir)
|
||||
bench.run(curr_dir, with_plots)
|
||||
bench.run(save_path=curr_dir)
|
||||
|
||||
|
||||
def main(args):
|
||||
parser = argparse.ArgumentParser(description="Run the benchmark suite.")
|
||||
parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
|
||||
parser.add_argument("-n", "--names", type=str, default='', required=False)
|
||||
parser.add_argument("-p", "--with-plots", dest='with_plots', action='store_true')
|
||||
parser.set_defaults(feature=False)
|
||||
args = parser.parse_args(args)
|
||||
run_all(args.result_dir, args.with_plots, args.names)
|
||||
run_all(args.result_dir, args.names)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
|
Reference in New Issue
Block a user