[PYTHON] Renamed triton.core -> triton.language (#92)
This commit is contained in:
committed by
Philippe Tillet
parent
41410012e8
commit
bfc0a7587d
@@ -81,6 +81,22 @@ def random(shape, dtype, device):
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
|
||||
"""
|
||||
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
|
||||
the 20-th and 80-th performance percentile.
|
||||
|
||||
:param fn: Function to benchmark
|
||||
:type fn: Callable
|
||||
:param warmup: Warmup time (in ms)
|
||||
:type warmup: int
|
||||
:param rep: Repetition time (in ms)
|
||||
:type rep: int
|
||||
:param grad_to_none: Reset the gradient of the provided tensor to None
|
||||
:type grad_to_none: torch.tensor, optional
|
||||
:param percentiles: Performance percentile to return in addition to the median.
|
||||
:type percentiles: list[float]
|
||||
"""
|
||||
|
||||
# Estimate the runtime of the function
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
@@ -125,13 +141,16 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""
|
||||
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
x_names,
|
||||
x_vals,
|
||||
y_name,
|
||||
y_vals,
|
||||
y_lines,
|
||||
line_arg,
|
||||
line_vals,
|
||||
line_names,
|
||||
plot_name,
|
||||
args,
|
||||
xlabel='',
|
||||
@@ -139,12 +158,38 @@ class Benchmark:
|
||||
x_log=False,
|
||||
y_log=False,
|
||||
):
|
||||
"""
|
||||
Constructor
|
||||
|
||||
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
|
||||
:type x_names: List[str]
|
||||
:param x_vals: List of values to use for the arguments in :code:`x_names`.
|
||||
:type x_vals: List[Any]
|
||||
:param line_arg: Argument name for which different values correspond to different lines in the plot.
|
||||
:type line_arg: str
|
||||
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
|
||||
:type line_vals: List[str]
|
||||
:param line_names: Label names for the different lines.
|
||||
:type line_names: List[str]
|
||||
:param plot_name: Name of the plot.
|
||||
:type plot_name: str
|
||||
:param args: List of arguments to remain fixed throughout the benchmark.
|
||||
:type args: List[str]
|
||||
:param xlabel: Label for the x axis of the plot.
|
||||
:type xlabel: str, optional
|
||||
:param ylabel: Label for the y axis of the plot.
|
||||
:type ylabel: str, optional
|
||||
:param x_log: Whether the x axis should be log scale.
|
||||
:type x_log: bool, optional
|
||||
:param y_log: Whether the y axis should be log scale.
|
||||
:type y_log: bool, optional
|
||||
"""
|
||||
self.x_names = x_names
|
||||
self.x_vals = x_vals
|
||||
self.x_log = x_log
|
||||
self.y_name = y_name
|
||||
self.y_vals = y_vals
|
||||
self.y_lines = y_lines
|
||||
self.line_arg = line_arg
|
||||
self.line_vals = line_vals
|
||||
self.line_names = line_names
|
||||
self.y_log = y_log
|
||||
# plot info
|
||||
self.xlabel = xlabel
|
||||
@@ -162,15 +207,15 @@ class Mark:
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import os
|
||||
y_mean = bench.y_lines
|
||||
y_min = [f'{x}-min' for x in bench.y_lines]
|
||||
y_max = [f'{x}-max' for x in bench.y_lines]
|
||||
y_mean = bench.line_names
|
||||
y_min = [f'{x}-min' for x in bench.line_names]
|
||||
y_max = [f'{x}-max' for x in bench.line_names]
|
||||
df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
|
||||
for x in bench.x_vals:
|
||||
x_args = {x_name: x for x_name in bench.x_names}
|
||||
row_mean, row_min, row_max = [], [], []
|
||||
for y in bench.y_vals:
|
||||
ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
|
||||
for y in bench.line_vals:
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
|
||||
try:
|
||||
y_mean, y_min, y_max = ret
|
||||
except TypeError:
|
||||
@@ -183,7 +228,7 @@ class Mark:
|
||||
plt.figure()
|
||||
ax = plt.subplot()
|
||||
x = bench.x_names[0]
|
||||
for y in bench.y_lines:
|
||||
for y in bench.line_names:
|
||||
y_min, y_max = df[y + '-min'], df[y + '-max']
|
||||
ax.plot(df[x], df[y], label=y)
|
||||
if y_min is not None and y_max is not None:
|
||||
@@ -199,7 +244,7 @@ class Mark:
|
||||
plt.show()
|
||||
if save_path:
|
||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||
df = df[[bench.x_names[0]] + bench.y_lines]
|
||||
df = df[[bench.x_names[0]] + bench.line_names]
|
||||
if print_data:
|
||||
print(df)
|
||||
if save_path:
|
||||
@@ -220,5 +265,11 @@ class Mark:
|
||||
|
||||
|
||||
def perf_report(benchmarks):
|
||||
"""
|
||||
Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
|
||||
|
||||
:param benchmarks: Benchmarking configurations.
|
||||
:type benchmarks: List of :class:`Benchmark`
|
||||
"""
|
||||
wrapper = lambda fn: Mark(fn, benchmarks)
|
||||
return wrapper
|
||||
|
Reference in New Issue
Block a user