import itertools

import pytest
import torch

import triton

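# test_op() checks triton.ops.matmul against torch.matmul over the sweep below:
# tile shapes (BLOCK_M/N/K), split-K factors, warp counts, pipeline stages,
# problem sizes, transposition flags (AT/BT) and dtypes. A single case can also
# be invoked by hand, e.g.
#   test_op(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, False, False, "float16")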
@pytest.mark.parametrize(
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
    itertools.chain(
        *[
            [
                # 1 warp
                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                # 2 warp
                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                # 8 warp
                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
                # variable input
                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
        ],
        # n-stage
        *[
            [
                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
        ]
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if capability[0] < 8 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    if DTYPE == "bfloat16" and SPLIT_K != 1:
        pytest.skip("bfloat16 matmuls don't allow split_k for now")
    torch.manual_seed(0)
    # nuke kernel decorators -- will set meta-parameters manually
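    # Replacing kernel.configs with a single triton.Config below forces the
    # autotuner to launch with exactly the meta-parameters under test instead
    # of searching over its default configuration space.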
    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
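    # With SPLIT_K > 1 the kernel accumulates partial K-slices into C, so C has
    # to start zeroed; the pre_hook below takes care of that before each launch.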
    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
    kernel = triton.ops._matmul.kernel
    kernel.configs = configs
    # kernel.run = kernel.run.run.run

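    # When M/N/K are None the problem defaults to a single tile, with K scaled
    # by SPLIT_K so that every split still covers one full BLOCK_K slice.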
    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
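    # For AT/BT the operand is allocated with swapped dimensions and viewed back
    # with .t(), so the same logical (M, K) x (K, N) product also runs on
    # transposed (non-contiguous) layouts.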
    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
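    # catch_oor() turns an out-of-resources failure (e.g. a tile that needs more
    # shared memory than the device provides) into a pytest skip rather than an error.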
    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)