|
|
|
@@ -189,7 +189,7 @@ Auto-Tuning
|
|
|
|
|
|
|
|
|
|
To use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects, which can be constructed as follows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 170-185
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -198,46 +198,14 @@ In order to use Triton's built-in auto-tuner in the above kernel, we need to def
|
|
|
|
|
    import triton

    autotune_configs = [
        triton.config(defines={'MB': '128', 'NB': '128', 'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
        triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
        triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
        triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -247,12 +215,12 @@ In order to use Triton's built-in auto-tuner in the above kernel, we need to def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 186-188
|
|
|
|
|
|
|
|
|
|
We also need to define an "autotuning key", i.e., a list of strings that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
|
|
|
|
|
Here, we want to re-tune our kernel only when the shape of the input matrices changes.
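As a concrete illustration, here is a minimal sketch of such a key, assuming the kernel's shape arguments are named :code:`M`, :code:`N` and :code:`K`:

.. code-block:: python

    # Re-tune whenever any of the matrix dimensions changes between calls;
    # the argument names below are assumed to match the kernel signature.
    autotune_key = ["M", "N", "K"]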
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 188-191
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -266,11 +234,11 @@ Here, we want to re-tune our kernel only when the shape of input matrices change
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 192-193
|
|
|
|
|
|
|
|
|
|
We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
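As a rough sketch only (the keyword names below simply mirror the variables defined above and are not guaranteed to match the exact :code:`triton.kernel` signature):

.. code-block:: python

    # Hypothetical construction: 'src' stands for the Triton-C source string of
    # the kernel, and the keyword names are assumptions, not the verified API.
    kernel = triton.kernel(src,
                           autotune_configs=autotune_configs,
                           autotune_key=autotune_key)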
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 193-238
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -326,7 +294,7 @@ We can now create an auto-tuned kernel by passing the `autotune_configs` and `au
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 239-244
|
|
|
|
|
|
|
|
|
|
Autograd Function
|
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
@@ -334,7 +302,7 @@ Autograd Function
|
|
|
|
|
Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
|
|
|
|
|
To do so, we just need to define a `forward` function that takes two tensors as input and returns a tensor as output.
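A minimal sketch of such a wrapper, assuming a hypothetical helper :code:`launch_matmul(a, b)` that allocates the output and launches the auto-tuned kernel:

.. code-block:: python

    import torch

    class _matmul(torch.autograd.Function):

        @staticmethod
        def forward(ctx, a, b):
            # launch_matmul is a stand-in for the actual allocation + kernel launch
            return launch_matmul(a, b)

    # expose the operation as a regular callable
    matmul = _matmul.apply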
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 244-265
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -366,7 +334,7 @@ To do so, we just need to define a `forward` function that takes a two tensors a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 266-271
|
|
|
|
|
|
|
|
|
|
Unit Test
|
|
|
|
|
-----------
|
|
|
|
@@ -374,7 +342,7 @@ Unit Test
|
|
|
|
|
We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
|
|
|
|
|
Note that we need to relax the :code:`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
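A sketch of such a check, assuming the wrapper above is exposed as :code:`matmul`; the shapes and tolerances here are illustrative only:

.. code-block:: python

    import torch

    a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
    b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
    c_0 = matmul(a, b)         # our Triton kernel
    c_1 = torch.matmul(a, b)   # cuBLAS reference
    print(c_0)
    print(c_1)
    # FP16 results do not match bit-for-bit, so loosen the tolerances
    print(torch.allclose(c_0, c_1, atol=1e-1, rtol=1e-2))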
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 271-280
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -397,28 +365,28 @@ Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torc
|
|
|
|
|
|
|
|
|
|
.. code-block:: none
|
|
|
|
|
|
|
|
|
|
tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
|
|
|
|
|
[196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
|
|
|
|
|
[198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
|
|
|
|
|
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
|
|
|
|
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
|
|
|
|
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
|
|
|
|
...,
|
|
|
|
|
[190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
|
|
|
|
|
[185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
|
|
|
|
|
[191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
|
|
|
|
|
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
|
|
|
|
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
|
|
|
|
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
|
|
|
|
device='cuda:0', dtype=torch.float16)
|
|
|
|
|
|
|
|
|
|
True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 281-327
|
|
|
|
|
|
|
|
|
|
Benchmark
|
|
|
|
|
--------------
|
|
|
|
@@ -432,7 +400,7 @@ To install CUTLASS, you need a recent version of cmake:
|
|
|
|
|
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
|
|
cd /tmp/
|
|
|
|
|
git clone https://github.com/NVIDIA/cutlass.git
|
|
|
|
|
cd cutlass
|
|
|
|
|
mkdir build
|
|
|
|
@@ -461,13 +429,13 @@ To re-install Triton with the updated CUTLASS bindings, run the following comman
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
|
|
export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
|
|
|
|
|
export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
|
|
|
|
|
pip uninstall -y triton
|
|
|
|
|
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
|
|
|
|
|
|
|
|
|
|
Which we can test as follows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 327-333
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -487,20 +455,20 @@ Which we can test as follows:
|
|
|
|
|
|
|
|
|
|
.. code-block:: none
|
|
|
|
|
|
|
|
|
|
tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
|
|
|
|
|
[196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
|
|
|
|
|
[198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
|
|
|
|
|
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
|
|
|
|
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
|
|
|
|
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
|
|
|
|
...,
|
|
|
|
|
[190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
|
|
|
|
|
[185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
|
|
|
|
|
[191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
|
|
|
|
|
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
|
|
|
|
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
|
|
|
|
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
|
|
|
|
device='cuda:0', dtype=torch.float16)
|
|
|
|
|
True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 334-339
|
|
|
|
|
|
|
|
|
|
Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
|
|
|
|
|
|
|
|
|
Square Matrix Performance
|
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to adapt the script to benchmark any other matrix shape.
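A condensed, self-contained sketch of such a sweep over square shapes, timing with CUDA events and comparing against :code:`torch.matmul`; the shapes, repetition counts and the :code:`2 * M * N * K` flop count are assumptions for illustration, and the CUTLASS binding is left out since it requires the optional build described above:

.. code-block:: python

    import torch

    def tflops(ms, M, N, K):
        # a matrix product performs 2 * M * N * K floating-point operations
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    def bench(fn, warmup=10, rep=50):
        # average latency of `fn` in milliseconds, measured with CUDA events
        for _ in range(warmup):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(rep):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / rep

    for N in [256, 512, 1024, 2048, 4096]:
        a = torch.rand((N, N), device='cuda', dtype=torch.float16)
        b = torch.rand((N, N), device='cuda', dtype=torch.float16)
        ms_triton = bench(lambda: matmul(a, b))       # wrapper defined above
        ms_cublas = bench(lambda: torch.matmul(a, b))
        print(f'{N}: Triton {tflops(ms_triton, N, N, N):.2f} TFLOPS, '
              f'cuBLAS {tflops(ms_cublas, N, N, N):.2f} TFLOPS')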
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 339-368
|
|
|
|
|
|
|
|
|
|
.. code-block:: default
|
|
|
|
|
|
|
|
|
@@ -552,14 +520,14 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. GENERATED FROM PYTHON SOURCE LINES 369-369
|
|
|
|
|
|
|
|
|
|
As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. rst-class:: sphx-glr-timing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
**Total running time of the script:** ( 1 minutes 6.502 seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
|
|
|
|