[DOCS] Updates and improvements (#87)

This commit is contained in:
Philippe Tillet
2021-04-22 10:27:02 -04:00
committed by Philippe Tillet
parent 39f4730305
commit 29e33e50b7
8 changed files with 195 additions and 70 deletions

View File

@@ -56,7 +56,7 @@ You will specifically learn about:
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: python
# :force:
#
# pid_m = triton.program_id(0)
# pid_n = triton.program_id(1)
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
@@ -85,8 +85,8 @@ You will specifically learn about:
# .. code-block:: Python
#
# pid = triton.program_id(0);
# grid_m = (M + BLOCK_M - 1) / BLOCK_M;
# grid_n = (N + BLOCK_N - 1) / BLOCK_N;
# grid_m = (M + BLOCK_M - 1) // BLOCK_M;
# grid_n = (N + BLOCK_N - 1) // BLOCK_N;
# pid_m = pid // grid_n;
# pid_n = pid % grid_n;
#
@@ -95,15 +95,15 @@ You will specifically learn about:
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
#
# .. code-block:: C
# .. code-block:: python
#
# pid = triton.program_id(0);
# width = GROUP_M * grid_n;
# group_id = pid / width;
# group_id = pid // width;
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
# pid_m = group_id * GROUP_M + (pid % group_size);
# pid_n = (pid % width) / (group_size);
# pid_n = (pid % width) // (group_size);
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
#
@@ -237,7 +237,7 @@ print(triton.testing.allclose(c_0, c_1))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[8192], # different possible values for `x_name`
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton"], # label name for the lines