[GH-PAGES] Updated website
@@ -29,18 +29,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\n\nThe above algorithm is, actually, fairly straightforward to implement in Triton.\nThe main difficulty comes from the computation of the memory locations at which blocks\n of :code:`A` and :code:`B` must be read in the inner loop. For that, we need\nmulti-dimensional pointer arithmetics.\n\n### Pointer Arithmetics\n\nFor a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b\ny :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.\nTherefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and\n:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:\n\n .. code-block:: python\n\n &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);\n &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);\n\nWhich means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:\n\n .. code-block:: python\n\n pid_m = triton.program_id(0)\n pid_n = triton.program_id(1)\n rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)\n rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)\n rk = triton.arange(0, BLOCK_SIZE_K)\n // pointer for A operand\n pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);\n // pointer for B operand\n pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);\n\nAnd then updated in the inner loop as follows:\n\n .. code-block:: python\n\n pa += BLOCK_SIZE_K * stride_a_1;\n pb += BLOCK_SIZE_K * stride_b_0;\n\n\n### L2 Cache Optimizations\n\nAs mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`\n block of :code:`C`.\nIt is important to remember that the order in which these blocks are computed does\nmatter, since it affects the L2 cache hit rate of our program. and unfortunately, a\na simple row-major ordering\n\n .. code-block:: Python\n\n pid = triton.program_id(0);\n grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;\n grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;\n pid_m = pid / grid_n;\n pid_n = pid % grid_n;\n\nis just not going to cut it.\n\nOne possible solution is to launch blocks in an order that promotes data reuse.\nThis can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before\nswitching to the next column:\n\n .. code-block:: python\n\n pid = triton.program_id(0);\n width = GROUP_M * grid_n;\n group_id = pid // width;\n # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M);\n pid_m = group_id * GROUP_M + (pid % group_size);\n pid_n = (pid % width) // (group_size);\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,\n# we can see that if we compute the output in row-major ordering, we need to load 90\n# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped\n# ordering, we only need to load 54 blocks.\n# .. image:: grouped_vs_row_major_ordering.png\n#\n# In practice, this can improve the performance of our matrix multiplication kernel by\n# more than 10\\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).\n#"
|
||||
"## Compute Kernel\n\nThe above algorithm is, actually, fairly straightforward to implement in Triton.\nThe main difficulty comes from the computation of the memory locations at which blocks\nof :code:`A` and :code:`B` must be read in the inner loop. For that, we need\nmulti-dimensional pointer arithmetics.\n\n### Pointer Arithmetics\n\nFor a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b\ny :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.\nTherefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and\n:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:\n\n .. code-block:: python\n\n &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);\n &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);\n\nWhich means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:\n\n .. code-block:: python\n\n pid_m = triton.program_id(0)\n pid_n = triton.program_id(1)\n rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)\n rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)\n rk = triton.arange(0, BLOCK_SIZE_K)\n // pointer for A operand\n pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);\n // pointer for B operand\n pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);\n\nAnd then updated in the inner loop as follows:\n\n .. code-block:: python\n\n pa += BLOCK_SIZE_K * stride_a_1;\n pb += BLOCK_SIZE_K * stride_b_0;\n\n\n### L2 Cache Optimizations\n\nAs mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`\nblock of :code:`C`.\nIt is important to remember that the order in which these blocks are computed does\nmatter, since it affects the L2 cache hit rate of our program. and unfortunately, a\na simple row-major ordering\n\n .. code-block:: Python\n\n pid = triton.program_id(0);\n grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;\n grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;\n pid_m = pid / grid_n;\n pid_n = pid % grid_n;\n\nis just not going to cut it.\n\nOne possible solution is to launch blocks in an order that promotes data reuse.\nThis can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before\nswitching to the next column:\n\n .. code-block:: python\n\n pid = triton.program_id(0);\n width = GROUP_M * grid_n;\n group_id = pid // width;\n # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M);\n pid_m = group_id * GROUP_M + (pid % group_size);\n pid_n = (pid % width) // (group_size);\n\nFor example, in the following matmul where each matrix is 9 blocks by 9 blocks,\nwe can see that if we compute the output in row-major ordering, we need to load 90\nblocks into SRAM to compute the first 9 output blocks, but if we do it in grouped\nordering, we only need to load 54 blocks.\n .. image:: grouped_vs_row_major_ordering.png\n\nIn practice, this can improve the performance of our matrix multiplication kernel by\nmore than 10\\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).\n\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -94,7 +83,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b, activation=None)\ntorch_output = torch.matmul(a, b)\nprint(f\"{triton_output=}\")\nprint(f\"{torch_output=}\")\nif triton.testing.allclose(triton_output, torch_output):\n print(\"\u2705 Triton and Torch match\")\nelse:\n print(\"\u274c Triton and Torch differ\")"
|
||||
"torch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b, activation=None)\ntorch_output = torch.matmul(a, b)\nprint(f\"triton_output={triton_output}\")\nprint(f\"torch_output={torch_output}\")\nif triton.testing.allclose(triton_output, torch_output):\n print(\"\u2705 Triton and Torch match\")\nelse:\n print(\"\u274c Triton and Torch differ\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -116,7 +116,7 @@ You will specifically learn about:
|
||||
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||
# pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
# pid_n = (pid % width) // (group_size);
|
||||
|
||||
#
|
||||
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
# we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
@@ -310,8 +310,8 @@ a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
triton_output = matmul(a, b, activation=None)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"{triton_output=}")
|
||||
print(f"{torch_output=}")
|
||||
print(f"triton_output={triton_output}")
|
||||
print(f"torch_output={torch_output}")
|
||||
if triton.testing.allclose(triton_output, torch_output):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
|
BIN
_images/grouped_vs_row_major_ordering.png
Normal file
After Width: | Height: | Size: 465 KiB |
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 25 KiB |
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
Before Width: | Height: | Size: 37 KiB After Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 55 KiB |
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 32 KiB |
@@ -231,10 +231,10 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
|
||||
vector-add-performance:
|
||||
size Triton Torch
|
||||
0 4096.0 9.600000 9.600000
|
||||
0 4096.0 9.540372 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 63.999998 63.999998
|
||||
3 32768.0 76.800002 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
@@ -254,7 +254,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 10.981 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 10.987 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
@@ -306,11 +306,11 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
3 640.0 682.666684 640.000002 160.000000
|
||||
4 768.0 702.171410 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 405.755985 198.936606
|
||||
94 12288.0 812.429770 415.661740 199.197579
|
||||
95 12416.0 810.840807 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.111113
|
||||
97 12672.0 811.007961 412.097543 199.167004
|
||||
93 12160.0 812.359066 405.755985 199.038365
|
||||
94 12288.0 812.429770 415.661740 199.298541
|
||||
95 12416.0 810.840807 412.149375 198.954424
|
||||
96 12544.0 809.290334 412.546756 199.209928
|
||||
97 12672.0 809.389265 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
@@ -329,7 +329,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 12.654 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 12.618 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
@@ -59,7 +59,7 @@ algorithm to multiply a (MxK) by a (KxN) matrix:
|
||||
|
||||
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-119
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-129
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
@@ -137,26 +137,14 @@ switching to the next column:
|
||||
pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
pid_n = (pid % width) // (group_size);
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 119-130
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
# we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
# ordering, we only need to load 54 blocks.
|
||||
# .. image:: grouped_vs_row_major_ordering.png
|
||||
#
|
||||
# In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
#
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
ordering, we only need to load 54 blocks.
|
||||
.. image:: grouped_vs_row_major_ordering.png
|
||||
|
||||
In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 131-134
|
||||
@@ -374,8 +362,8 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
triton_output = matmul(a, b, activation=None)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"{triton_output=}")
|
||||
print(f"{torch_output=}")
|
||||
print(f"triton_output={triton_output}")
|
||||
print(f"torch_output={torch_output}")
|
||||
if triton.testing.allclose(triton_output, torch_output):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
@@ -484,36 +472,36 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 128.0 0.455111 ... 0.512000 0.512000
|
||||
1 256.0 2.730667 ... 3.276800 2.978909
|
||||
2 384.0 7.372800 ... 8.507077 8.507077
|
||||
3 512.0 14.563555 ... 16.384000 15.420235
|
||||
2 384.0 7.372800 ... 7.899428 7.899428
|
||||
3 512.0 14.563555 ... 16.384000 16.384000
|
||||
4 640.0 22.260869 ... 24.380953 24.380953
|
||||
5 768.0 32.768000 ... 34.028308 34.028308
|
||||
6 896.0 39.025776 ... 40.140799 35.123201
|
||||
6 896.0 39.025776 ... 40.140799 35.150663
|
||||
7 1024.0 49.932191 ... 52.428801 52.428801
|
||||
8 1152.0 44.566925 ... 46.656000 45.938215
|
||||
8 1152.0 44.566925 ... 46.656000 46.656000
|
||||
9 1280.0 51.200001 ... 56.109587 56.109587
|
||||
10 1408.0 64.138541 ... 64.902096 64.138541
|
||||
11 1536.0 80.430545 ... 76.106321 75.296679
|
||||
12 1664.0 63.372618 ... 62.492442 62.061463
|
||||
12 1664.0 62.929456 ... 62.061463 62.061463
|
||||
13 1792.0 72.983276 ... 69.810085 69.379162
|
||||
14 1920.0 68.435645 ... 67.764707 69.818184
|
||||
15 2048.0 73.584279 ... 75.234154 74.898285
|
||||
16 2176.0 83.500614 ... 81.143743 78.916269
|
||||
17 2304.0 68.056616 ... 73.501144 73.051599
|
||||
18 2432.0 71.125224 ... 80.269900 80.963875
|
||||
19 2560.0 77.833728 ... 76.920185 76.382283
|
||||
20 2688.0 80.027544 ... 79.524227 82.284288
|
||||
21 2816.0 83.392363 ... 79.587973 76.785575
|
||||
22 2944.0 82.509987 ... 79.230573 79.993627
|
||||
23 3072.0 81.589488 ... 83.761985 82.301023
|
||||
24 3200.0 84.768213 ... 89.385477 89.012517
|
||||
25 3328.0 80.617354 ... 80.707733 86.217120
|
||||
26 3456.0 81.518272 ... 85.223646 82.183044
|
||||
27 3584.0 84.033077 ... 93.564405 95.047985
|
||||
28 3712.0 86.267139 ... 88.015279 89.194055
|
||||
29 3840.0 84.874902 ... 88.402879 87.217666
|
||||
30 3968.0 92.442373 ... 87.850207 87.347124
|
||||
31 4096.0 93.531519 ... 85.926841 85.871865
|
||||
14 1920.0 69.120002 ... 70.892307 69.120002
|
||||
15 2048.0 73.584279 ... 74.898285 74.565406
|
||||
16 2176.0 83.155572 ... 80.817862 79.855747
|
||||
17 2304.0 68.446623 ... 72.828879 73.275679
|
||||
18 2432.0 71.305746 ... 82.388456 81.908060
|
||||
19 2560.0 78.019048 ... 77.283019 75.676673
|
||||
20 2688.0 83.552988 ... 83.552988 83.922689
|
||||
21 2816.0 81.827785 ... 77.330158 79.154642
|
||||
22 2944.0 81.166173 ... 77.747321 79.483304
|
||||
23 3072.0 79.863336 ... 82.661468 82.420822
|
||||
24 3200.0 83.660130 ... 90.395483 85.906037
|
||||
25 3328.0 83.226931 ... 87.368079 83.613586
|
||||
26 3456.0 80.220468 ... 81.600781 83.459178
|
||||
27 3584.0 87.466332 ... 92.887804 84.983685
|
||||
28 3712.0 84.159518 ... 83.178475 83.666116
|
||||
29 3840.0 83.591840 ... 84.228485 85.663823
|
||||
30 3968.0 91.885495 ... 84.680037 84.154440
|
||||
31 4096.0 89.181212 ... 90.260743 90.200084
|
||||
|
||||
[32 rows x 5 columns]
|
||||
|
||||
@@ -523,7 +511,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 2 minutes 30.126 seconds)
|
||||
**Total running time of the script:** ( 2 minutes 29.710 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
@@ -5,12 +5,12 @@
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**03:53.760** total execution time for **getting-started_tutorials** files:
|
||||
**03:53.315** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:30.126 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:29.710 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.654 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.618 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.981 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.987 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
@@ -319,10 +319,10 @@ for different problem sizes.</p>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>vector-add-performance:
|
||||
size Triton Torch
|
||||
0 4096.0 9.600000 9.600000
|
||||
0 4096.0 9.540372 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 63.999998 63.999998
|
||||
3 32768.0 76.800002 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
@@ -337,7 +337,7 @@ for different problem sizes.</p>
|
||||
15 134217728.0 851.577704 850.656574
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.981 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.987 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
|
@@ -391,11 +391,11 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
3 640.0 682.666684 640.000002 160.000000
|
||||
4 768.0 702.171410 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 405.755985 198.936606
|
||||
94 12288.0 812.429770 415.661740 199.197579
|
||||
95 12416.0 810.840807 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.111113
|
||||
97 12672.0 811.007961 412.097543 199.167004
|
||||
93 12160.0 812.359066 405.755985 199.038365
|
||||
94 12288.0 812.429770 415.661740 199.298541
|
||||
95 12416.0 810.840807 412.149375 198.954424
|
||||
96 12544.0 809.290334 412.546756 199.209928
|
||||
97 12672.0 809.389265 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
</pre></div>
|
||||
@@ -409,7 +409,7 @@ This means that – when temporary data is too large to fit entirely in the GPU
|
||||
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
|
||||
</ul>
|
||||
</div></blockquote>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.654 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.618 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|
@@ -241,11 +241,9 @@ algorithm to multiply a (MxK) by a (KxN) matrix:</p>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks</p>
|
||||
<blockquote>
|
||||
<div><p>of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need</p>
|
||||
</div></blockquote>
|
||||
<p>multi-dimensional pointer arithmetics.</p>
|
||||
The main difficulty comes from the computation of the memory locations at which blocks
|
||||
of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need
|
||||
multi-dimensional pointer arithmetics.</p>
|
||||
<div class="section" id="pointer-arithmetics">
|
||||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given b
|
||||
@@ -282,11 +280,9 @@ Therefore, blocks of pointers for <code class="code docutils literal notranslate
|
||||
</div>
|
||||
<div class="section" id="l2-cache-optimizations">
|
||||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||||
<dl class="simple">
|
||||
<dt>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code></dt><dd><p>block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.</p>
|
||||
</dd>
|
||||
</dl>
|
||||
<p>It is important to remember that the order in which these blocks are computed does
|
||||
<p>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code>
|
||||
block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||||
It is important to remember that the order in which these blocks are computed does
|
||||
matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
|
||||
a simple row-major ordering</p>
|
||||
<blockquote>
|
||||
@@ -313,17 +309,15 @@ switching to the next column:</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># For example, in the following matmul where each matrix is 9 blocks by 9 blocks,</span>
|
||||
<span class="c1"># we can see that if we compute the output in row-major ordering, we need to load 90</span>
|
||||
<span class="c1"># blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped</span>
|
||||
<span class="c1"># ordering, we only need to load 54 blocks.</span>
|
||||
<span class="c1"># .. image:: grouped_vs_row_major_ordering.png</span>
|
||||
<span class="c1">#</span>
|
||||
<span class="c1"># In practice, this can improve the performance of our matrix multiplication kernel by</span>
|
||||
<span class="c1"># more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</span>
|
||||
<span class="c1">#</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
ordering, we only need to load 54 blocks.</p>
|
||||
<blockquote>
|
||||
<div><img alt="../../_images/grouped_vs_row_major_ordering.png" src="../../_images/grouped_vs_row_major_ordering.png" />
|
||||
</div></blockquote>
|
||||
<p>In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="final-result">
|
||||
@@ -501,8 +495,8 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="n">triton_output</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||||
<span class="n">torch_output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">triton_output</span><span class="si">=}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">torch_output</span><span class="si">=}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"triton_output=</span><span class="si">{</span><span class="n">triton_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"torch_output=</span><span class="si">{</span><span class="n">torch_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">triton_output</span><span class="p">,</span> <span class="n">torch_output</span><span class="p">):</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="s2">"✅ Triton and Torch match"</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
@@ -582,41 +576,41 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 128.0 0.455111 ... 0.512000 0.512000
|
||||
1 256.0 2.730667 ... 3.276800 2.978909
|
||||
2 384.0 7.372800 ... 8.507077 8.507077
|
||||
3 512.0 14.563555 ... 16.384000 15.420235
|
||||
2 384.0 7.372800 ... 7.899428 7.899428
|
||||
3 512.0 14.563555 ... 16.384000 16.384000
|
||||
4 640.0 22.260869 ... 24.380953 24.380953
|
||||
5 768.0 32.768000 ... 34.028308 34.028308
|
||||
6 896.0 39.025776 ... 40.140799 35.123201
|
||||
6 896.0 39.025776 ... 40.140799 35.150663
|
||||
7 1024.0 49.932191 ... 52.428801 52.428801
|
||||
8 1152.0 44.566925 ... 46.656000 45.938215
|
||||
8 1152.0 44.566925 ... 46.656000 46.656000
|
||||
9 1280.0 51.200001 ... 56.109587 56.109587
|
||||
10 1408.0 64.138541 ... 64.902096 64.138541
|
||||
11 1536.0 80.430545 ... 76.106321 75.296679
|
||||
12 1664.0 63.372618 ... 62.492442 62.061463
|
||||
12 1664.0 62.929456 ... 62.061463 62.061463
|
||||
13 1792.0 72.983276 ... 69.810085 69.379162
|
||||
14 1920.0 68.435645 ... 67.764707 69.818184
|
||||
15 2048.0 73.584279 ... 75.234154 74.898285
|
||||
16 2176.0 83.500614 ... 81.143743 78.916269
|
||||
17 2304.0 68.056616 ... 73.501144 73.051599
|
||||
18 2432.0 71.125224 ... 80.269900 80.963875
|
||||
19 2560.0 77.833728 ... 76.920185 76.382283
|
||||
20 2688.0 80.027544 ... 79.524227 82.284288
|
||||
21 2816.0 83.392363 ... 79.587973 76.785575
|
||||
22 2944.0 82.509987 ... 79.230573 79.993627
|
||||
23 3072.0 81.589488 ... 83.761985 82.301023
|
||||
24 3200.0 84.768213 ... 89.385477 89.012517
|
||||
25 3328.0 80.617354 ... 80.707733 86.217120
|
||||
26 3456.0 81.518272 ... 85.223646 82.183044
|
||||
27 3584.0 84.033077 ... 93.564405 95.047985
|
||||
28 3712.0 86.267139 ... 88.015279 89.194055
|
||||
29 3840.0 84.874902 ... 88.402879 87.217666
|
||||
30 3968.0 92.442373 ... 87.850207 87.347124
|
||||
31 4096.0 93.531519 ... 85.926841 85.871865
|
||||
14 1920.0 69.120002 ... 70.892307 69.120002
|
||||
15 2048.0 73.584279 ... 74.898285 74.565406
|
||||
16 2176.0 83.155572 ... 80.817862 79.855747
|
||||
17 2304.0 68.446623 ... 72.828879 73.275679
|
||||
18 2432.0 71.305746 ... 82.388456 81.908060
|
||||
19 2560.0 78.019048 ... 77.283019 75.676673
|
||||
20 2688.0 83.552988 ... 83.552988 83.922689
|
||||
21 2816.0 81.827785 ... 77.330158 79.154642
|
||||
22 2944.0 81.166173 ... 77.747321 79.483304
|
||||
23 3072.0 79.863336 ... 82.661468 82.420822
|
||||
24 3200.0 83.660130 ... 90.395483 85.906037
|
||||
25 3328.0 83.226931 ... 87.368079 83.613586
|
||||
26 3456.0 80.220468 ... 81.600781 83.459178
|
||||
27 3584.0 87.466332 ... 92.887804 84.983685
|
||||
28 3712.0 84.159518 ... 83.178475 83.666116
|
||||
29 3840.0 83.591840 ... 84.228485 85.663823
|
||||
30 3968.0 91.885495 ... 84.680037 84.154440
|
||||
31 4096.0 89.181212 ... 90.260743 90.200084
|
||||
|
||||
[32 rows x 5 columns]
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 30.126 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 29.710 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||
|
@@ -174,7 +174,7 @@
|
||||
|
||||
<div class="section" id="computation-times">
|
||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>03:53.760</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<p><strong>03:53.315</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<table class="docutils align-default">
|
||||
<colgroup>
|
||||
<col style="width: 85%" />
|
||||
@@ -183,15 +183,15 @@
|
||||
</colgroup>
|
||||
<tbody>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
||||
<td><p>02:30.126</p></td>
|
||||
<td><p>02:29.710</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>01:12.654</p></td>
|
||||
<td><p>01:12.618</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:10.981</p></td>
|
||||
<td><p>00:10.987</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
|