[DOCS] Various improvements and typo fixes
Committed by Philippe Tillet
Parent: 3f6ba1020d
Commit: 1fdb465b71
@@ -229,7 +229,13 @@ def make_kernel(device, dtype):
     cache = make_kernel.cache
     if key not in cache:
         defines = {'TYPE': dtype}
-        cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
+        cache[key] = triton.kernel(
+            src,
+            device=device,
+            defines=defines,
+            autotune_configs=autotune_configs,
+            autotune_key=autotune_key,
+        )
     return cache[key]
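As an aside, here is a minimal sketch of the memoization pattern this hunk reformats, assuming the pre-1.0 `triton.kernel` API shown above; `src`, `autotune_configs`, `autotune_key`, and the `(device, dtype)` cache key are placeholders for values defined elsewhere in the tutorial.

    import triton

    # Placeholders (assumptions) for values the tutorial defines earlier:
    src = "..."                     # Triton-C source string for the kernel
    autotune_configs = []           # candidate launch configurations to autotune over
    autotune_key = ['M', 'N', 'K']  # argument names that trigger re-autotuning

    def make_kernel(device, dtype):
        # Compile once per (device, dtype) pair and memoize the compiled kernel
        # on a function attribute so repeated calls are cheap.
        key = (device, dtype)
        cache = make_kernel.cache
        if key not in cache:
            defines = {'TYPE': dtype}
            cache[key] = triton.kernel(
                src,
                device=device,
                defines=defines,
                autotune_configs=autotune_configs,
                autotune_key=autotune_key,
            )
        return cache[key]

    make_kernel.cache = dict()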
@@ -319,7 +325,7 @@ print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
 # .. code-block:: bash
 #
 #     export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
-#     export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
+#     export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
 #     pip uninstall -y triton
 #     pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
 #
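A small sanity check, not part of the commit, that the two environment variables above point at real directories before reinstalling; only the variable names are taken from the snippet, the check itself is illustrative.

    import os

    # Verify the CUTLASS paths referenced above before running `pip install`.
    for var in ("CUTLASS_INCLUDE_DIR", "CUTLASS_LIBRARY_DIR"):
        path = os.environ.get(var)
        if not path or not os.path.isdir(path):
            raise RuntimeError(f"{var} should point to an existing directory, got {path!r}")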
@@ -343,8 +349,8 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
         x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
         y_name='provider', # argument name whose value corresponds to a different line in the plot
-        y_vals=['torch', 'triton', 'cutlass'], # possible keys for `y_name`
-        y_lines=["Torch", "Triton", 'CUTLASS'], # label name for the lines
+        y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
+        y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
         ylabel="TFLOPS", # label name for the y-axis
         plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
         args={}
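These keyword arguments configure one line per `provider` on a single plot; below is a sketch of how they would typically be wrapped, assuming the `triton.testing.Benchmark` and `triton.testing.perf_report` helpers (the wrapper names are not shown in the hunk and are an assumption).

    import triton
    import triton.testing

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['M', 'N', 'K'],                  # argument names used as the x-axis
            x_vals=[256 * i for i in range(2, 33)],   # values swept along the x-axis
            y_name='provider',                        # argument selecting which line a point belongs to
            y_vals=['cublas', 'triton', 'cutlass'],   # possible values of `provider`
            y_lines=["cuBLAS", "Triton", 'CUTLASS'],  # legend labels for those lines
            ylabel="TFLOPS",                          # y-axis label
            plot_name="matmul-performance",           # plot title, also used as the output file name
            args={},                                  # extra fixed arguments for the benchmark
        )
    )
    def benchmark(M, N, K, provider):
        # Body elided; the next hunk shows it timing torch.matmul and the
        # Triton dot kernel with triton.testing.do_bench.
        ...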
@@ -353,7 +359,7 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
 def benchmark(M, N, K, provider):
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
-    if provider == 'torch':
+    if provider == 'cublas':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
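The plot's y-axis is labeled TFLOPS, so the millisecond timings returned by `triton.testing.do_bench` need a unit conversion; here is a small helper for that, using the standard 2 * M * N * K flop count of an (M, K) by (K, N) matmul (the helper name is illustrative, not from the tutorial).

    def tflops(M, N, K, time_ms):
        # An (M, K) x (K, N) matmul performs 2 * M * N * K floating-point
        # operations (one multiply and one add per accumulated product).
        return 2 * M * N * K * 1e-12 / (time_ms * 1e-3)

    # Example: the slowest run (max_ms) gives the lowest throughput, so the
    # min/max labels swap when converting to TFLOPS.
    # perf, min_perf, max_perf = tflops(M, N, K, ms), tflops(M, N, K, max_ms), tflops(M, N, K, min_ms)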