[DOCS] Various improvements and typo fixes

This commit is contained in:
Philippe Tillet
2021-03-29 02:35:13 -04:00
committed by Philippe Tillet
parent 3f6ba1020d
commit 1fdb465b71
17 changed files with 207 additions and 99 deletions

View File

@@ -229,7 +229,13 @@ def make_kernel(device, dtype):
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
cache[key] = triton.kernel(
src,
device=device,
defines=defines,
autotune_configs=autotune_configs,
autotune_key=autotune_key,
)
return cache[key]
@@ -319,7 +325,7 @@ print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
# .. code-block:: bash
#
# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
# pip uninstall -y triton
# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
@@ -343,8 +349,8 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'CUTLASS'], # label name for the lines
y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
@@ -353,7 +359,7 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'torch':
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))