[RUNTIME] Added auto-alignment mechanism (#71)
This PR adds an automatic memory alignment mechanism in the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power of two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on-the-fly for all auto-tunable kernels. There is a cache that remembers all the kernels compiled for each possible configuration. This PR also includes substantial cleaning of the Python API. This adds 2-3us overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this is preferred for now.
This commit is contained in:
committed by
Philippe Tillet
parent
ff62f7fffc
commit
62835a0979
@@ -1,12 +1,19 @@
|
||||
import triton
|
||||
import torch
|
||||
import os
|
||||
|
||||
# square benchmarks
|
||||
def rounded_linspace(low, high, steps, div):
|
||||
ret = torch.linspace(low, high, steps)
|
||||
ret = (ret.int() + div - 1) // div * div
|
||||
ret = torch.unique(ret)
|
||||
return list(map(int, ret))
|
||||
|
||||
# Square benchmarks
|
||||
nt = {False: "n", True: "t"}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["M", "N", "K"],
|
||||
x_vals=[512 * i for i in range(1, 16)],
|
||||
x_vals=rounded_linspace(512, 8192, 17, 128),
|
||||
y_name="provider",
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
@@ -17,16 +24,29 @@ square_confs = [
|
||||
) for AT in [False, True] for BT in [False, True]
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
import os
|
||||
# Transformer training benchmarks
|
||||
transformer_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=[x],
|
||||
x_vals = rounded_linspace(NK//16, NK, 33, 128),
|
||||
y_name="provider",
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
loglog=False,
|
||||
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
|
||||
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
|
||||
) for NK in [8192]\
|
||||
for i, x in enumerate(["N", "K"])\
|
||||
for M in [2048]
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
if AT:
|
||||
a = a.t()
|
||||
if BT:
|
||||
b = b.t()
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
num_flops = 2 * M * N * K
|
||||
if provider == "torch":
|
||||
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
|
||||
@@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
|
||||
# run program specified by CUTLASS_PROFILER env variable
|
||||
layout_a = "column" if AT else "row"
|
||||
layout_b = "column" if BT else "row"
|
||||
@@ -61,6 +80,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
f"--warmup-iterations={warmup}",
|
||||
f"--profiling-iterations={rep}",
|
||||
f"--output={fname}",
|
||||
"--dist=uniform,min:0,max:1,scale:-1",
|
||||
"--verbose=false",
|
||||
]
|
||||
# run cmd
|
||||
@@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
|
||||
return cutlass_tflops
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
bench_op.run()
|
||||
|
@@ -38,4 +38,4 @@ def main(args):
|
||||
run_all(args.result_dir, args.with_plots, args.names)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
main(sys.argv[1:])
|
||||
|
Reference in New Issue
Block a user