[DOCS] Fix formatting mistakes (#192)
docs/getting-started/tutorials/grouped_vs_row_major_ordering.png (new binary file, 465 KiB; binary content not shown)
@@ -46,7 +46,7 @@ You will specifically learn about:
 #
 # The above algorithm is, actually, fairly straightforward to implement in Triton.
 # The main difficulty comes from the computation of the memory locations at which blocks
-# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
+# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
 # multi-dimensional pointer arithmetics.
 #
 # Pointer Arithmetics
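The multi-dimensional pointer arithmetic this comment refers to comes down to combining block row and column offsets with the tensor's strides via broadcasting. Below is a minimal NumPy sketch of that arithmetic; the matrix sizes, block sizes, and the pid_m / k_step values are made up purely for illustration and are not taken from the tutorial.

    import numpy as np

    M, K = 8, 8                      # toy row-major matrix A of shape (M, K)
    BLOCK_M, BLOCK_K = 4, 4          # tile size read per inner-loop step
    stride_am, stride_ak = K, 1      # strides of A, in elements

    pid_m, k_step = 1, 0             # hypothetical program id along M and K-loop step

    offs_m = pid_m * BLOCK_M + np.arange(BLOCK_M)    # rows covered by this tile
    offs_k = k_step * BLOCK_K + np.arange(BLOCK_K)   # columns covered by this tile

    # Broadcasting yields a [BLOCK_M, BLOCK_K] grid of flat offsets into A's storage,
    # playing the role of the block of pointers a Triton kernel would hand to tl.load.
    a_offsets = offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak

    a = np.arange(M * K).reshape(M, K)
    assert np.array_equal(a.ravel()[a_offsets], a[4:8, 0:4])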
@@ -88,7 +88,7 @@ You will specifically learn about:
 # ~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
-# block of :code:`C`.
+# block of :code:`C`.
 # It is important to remember that the order in which these blocks are computed does
 # matter, since it affects the L2 cache hit rate of our program, and unfortunately, a
 # simple row-major ordering
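The "simple row-major ordering" mentioned here maps program id pid to block (pid // grid_n, pid % grid_n), so consecutive program instances walk across an entire row of C blocks before ever reusing a block-row of A. A tiny illustrative sketch, using a hypothetical 9 by 9 grid of output blocks matching the example in the next hunk:

    grid_m, grid_n = 9, 9   # hypothetical grid of C blocks
    row_major = [(pid // grid_n, pid % grid_n) for pid in range(grid_m * grid_n)]
    print(row_major[:9])    # the first 9 programs all sit on block-row 0 of C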
@@ -116,7 +116,7 @@ You will specifically learn about:
 # group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
 # pid_m = group_id * GROUP_M + (pid % group_size);
 # pid_n = (pid % width) // (group_size);
-
+#
 # For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
 # we can see that if we compute the output in row-major ordering, we need to load 90
 # blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
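To make the block-loading argument concrete, the sketch below implements the grouped re-ordering from the pseudo-code above and counts how many distinct blocks of A and B the first 9 program instances touch in a 9 by 9 grid. The pid, width, and group_id lines fall outside the quoted context, so they are filled in here the way the tutorial's full listing writes them; GROUP_M = 3 is an assumed group size chosen to match the worked example.

    grid_m, grid_n, grid_k, GROUP_M = 9, 9, 9, 3   # each matrix is 9 blocks by 9 blocks

    def grouped(pid):
        # Grouped ("super-grouped") ordering from the pseudo-code above.
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        return pid_m, pid_n

    def blocks_loaded(order):
        # Distinct block-rows of A and block-columns of B needed by the first 9 programs;
        # each block-row / block-column is grid_k blocks deep along K.
        first9 = order[:9]
        a_rows = {m for m, _ in first9}
        b_cols = {n for _, n in first9}
        return (len(a_rows) + len(b_cols)) * grid_k

    row_major = [(pid // grid_n, pid % grid_n) for pid in range(grid_m * grid_n)]
    grouped_order = [grouped(pid) for pid in range(grid_m * grid_n)]
    print(blocks_loaded(row_major))      # 90 blocks for the first 9 output blocks
    print(blocks_loaded(grouped_order))  # 54 blocks with grouped ordering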
@@ -310,8 +310,8 @@ a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 triton_output = matmul(a, b, activation=None)
 torch_output = torch.matmul(a, b)
-print(f"{triton_output=}")
-print(f"{torch_output=}")
+print(f"triton_output={triton_output}")
+print(f"torch_output={torch_output}")
 if triton.testing.allclose(triton_output, torch_output):
     print("✅ Triton and Torch match")
 else:
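If triton.testing.allclose is unavailable in a given Triton version, a roughly equivalent sanity check can be written with plain torch.allclose on the tensors defined above; the fp16 tolerances below are illustrative guesses, not values taken from the tutorial.

    # Rough PyTorch-only equivalent of the check above, assuming torch is already
    # imported as in the tutorial; rtol/atol are chosen loosely for fp16 inputs.
    if torch.allclose(triton_output, torch_output, rtol=1e-2, atol=1e-2):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")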
Binary image file not shown (size before this change: 469 KiB)