|
|
|
Auto-Tuning
~~~~~~~~~~~~
|
|
|
|
|
|
|
|
|
|
In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects that can be constructed as follows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 170-217
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    import triton

    autotune_configs = [
        triton.config(defines={'MB': '128', 'NB': '128', 'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '64',  'NB': '128', 'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '128', 'NB': '64',  'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '64',  'NB': '64',  'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '32',  'NB': '128', 'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '128', 'NB': '32',  'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '64',  'NB': '32',  'KB': '64'}, num_warps=2),
        triton.config(defines={'MB': '32',  'NB': '64',  'KB': '64'}, num_warps=2),
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 218-220
|
|
|
|
|
|
|
|
|
|
We also need to define a list of :code:`string` values (i.e., an "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
|
|
|
|
|
Here, we want to re-tune our kernel only when the shape of the input matrices changes.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 220-223
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
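
    # The body of this block is elided in the diff; the following is a
    # plausible sketch (argument names assumed): re-tune whenever any of the
    # three matrix dimensions changes.
    autotune_key = ["M", "N", "K"]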
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 224-225
|
|
|
|
|
|
|
|
|
|
We can now create an auto-tuned kernel by passing the :code:`autotune_configs` and :code:`autotune_key` lists to the constructor of the :code:`triton.kernel` class.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 225-270
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
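
    # The body of this block is elided in the diff. A minimal sketch of the
    # construction it describes: `src` stands for the Triton-C source string
    # written earlier in this tutorial, and the keyword names of the
    # triton.kernel constructor are assumptions, not confirmed by this diff.
    import torch

    kernel = triton.kernel(src,
                           device=torch.device('cuda'),
                           autotune_configs=autotune_configs,
                           autotune_key=autotune_key)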
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 271-276
|
|
|
|
|
|
|
|
|
|
Autograd Function
~~~~~~~~~~~~~~~~~~
|
|
|
|
|
|
|
|
|
Now we are ready to expose our auto-tuned kernel as a :code:`torch.autograd.Function`.
|
|
|
|
|
To do so, we just need to define a :code:`forward` function that takes two tensors as input and returns a tensor as output.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 276-297
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
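
    # The body of this block is elided in the diff. A hedged sketch of what it
    # describes: a thin torch.autograd.Function whose forward allocates the
    # output and launches the auto-tuned kernel. The launch arguments and the
    # grid callback are assumptions built on the `kernel` sketch above.
    class _dot(torch.autograd.Function):

        @staticmethod
        def forward(ctx, a, b):
            M, K = a.shape
            _, N = b.shape
            # allocate the output tensor
            c = torch.empty((M, N), device=a.device, dtype=a.dtype)
            # one program instance per [MB, NB] tile of the output
            grid = lambda opt: (triton.cdiv(M, opt.MB), triton.cdiv(N, opt.NB))
            kernel(a, b, c, M, N, K, grid=grid)
            return c

    dot = _dot.apply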
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 298-303
|
|
|
|
|
|
|
|
|
|
Unit Test
-----------
|
|
|
|
|
|
|
|
|
We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
|
|
|
|
|
Note that we need to modify the :code:`atol` and :code:`rtol` parameters of :code:`torch.allclose` to account for the fact that we are comparing FP16 tensors.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 303-312
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
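
    # The body of this block is elided in the diff; the following is a
    # plausible reconstruction (the matrix shapes are assumptions). Looser
    # tolerances are needed because the FP16 outputs of the two
    # implementations can differ by more than the defaults allow.
    torch.manual_seed(0)
    a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
    b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
    c_0 = dot(a, b)
    c_1 = torch.matmul(a, b)
    print(c_0)
    print(c_1)
    print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))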
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: none
|
|
|
|
|
|
|
|
|
|
    tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
            [196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
            [198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
            ...,
            [190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
            [185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
            [191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
           device='cuda:0', dtype=torch.float16)
    tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
            [196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
            [198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
            ...,
            [190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
            [185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
            [191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
           device='cuda:0', dtype=torch.float16)
    True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 313-359
|
|
|
|
|
|
|
|
|
|
Benchmark
--------------
|
|
|
|
To install CUTLASS, you need a recent version of cmake:
|
|
|
|
|
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
|
|
    cd /tmp/
    git clone https://github.com/NVIDIA/cutlass.git
    cd cutlass
    mkdir build
|
|
|
|
To re-install Triton with the updated CUTLASS bindings, run the following commands:
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
|
|
    export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
    export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
    pip uninstall -y triton
    pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
|
|
|
|
|
|
|
|
|
|
We can then test the updated bindings as follows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 359-365
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
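
    # The body of this block is elided in the diff. Hedged sketch:
    # `triton.testing.cutlass_matmul` is an assumed name for the CUTLASS
    # binding exposed by Triton for benchmarking.
    c_2 = triton.testing.cutlass_matmul(a, b)
    print(c_2)
    print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))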
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: none
|
|
|
|
|
|
|
|
|
|
    tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
            [196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
            [198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
            ...,
            [190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
            [185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
            [191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
           device='cuda:0', dtype=torch.float16)
    True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 366-371
|
|
|
|
|
|
|
|
|
|
Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
|
|
|
|
|
|
|
|
|
Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to modify this script to benchmark any other matrix shape.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 371-400
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
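
    # The body of this block is elided in the diff. A hedged reconstruction of
    # the benchmark loop: the shape range is an assumption, and
    # triton.testing.do_bench is assumed to return a runtime in milliseconds.
    def perf(ms, M, N, K):
        # a matrix multiplication performs 2*M*N*K floating-point operations
        return 2. * M * N * K / (ms * 1e-3) * 1e-12  # TFLOPS

    for N in [512, 1024, 2048, 4096]:
        a = torch.rand((N, N), device='cuda', dtype=torch.float16)
        b = torch.rand((N, N), device='cuda', dtype=torch.float16)
        triton_ms = triton.testing.do_bench(lambda: dot(a, b))
        cutlass_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
        print(f'N={N}: triton {perf(triton_ms, N, N, N):.2f} TFLOPS -- '
              f'cutlass {perf(cutlass_ms, N, N, N):.2f} TFLOPS')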
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 401-401
|
|
|
|
|
|
|
|
|
|
As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. rst-class:: sphx-glr-timing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
**Total running time of the script:** ( 1 minutes 10.094 seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
|
|
|
|