Better NVIDIA Pascal GPU Support (#827)

This PR clarifies, via its tests, which features work on P100, even though
Pascal is not officially or fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100 (the compute-capability check behind these
  skips is sketched right after this list):
  - Atomic RMW
  - `tl.dot()` (perhaps not all patterns, but most `tl.dot()` tests do not
    work on P100)
- Add an explicit error if the shared memory size is >= 64KB on P100.
  - Without this check, `cuLaunchKernel()` fails with a generic
    `Invalid CUDA argument` error, which is hard to interpret. This PR
    instead makes Triton raise a descriptive error during codegen when
    `sm < 70`. The check happens in C/C++, so it adds no overhead to
    Triton's Python runtime.
- 3 tests (see below) currently fail because of this limit (65536 bytes is,
  for example, exactly a 128x128 float32 tile). They are not marked as
  skipped, because a future codegen update could change the shared memory
  footprint of these and other tests.
- This change does not affect Triton-MLIR. Hopefully Triton-MLIR's generic
  `tl.dot()` implementation will support P100.
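
For reference, the skip condition added to the tests boils down to a
compute-capability check (the exact diffs are below). A minimal sketch of the
same idea using the public PyTorch API, rather than the internal
`_triton.runtime.cc` helper the tests use:

```python
import pytest
import torch


def has_min_cc(major: int) -> bool:
    # torch reports (major, minor), e.g. (6, 0) on P100 and (7, 0) on V100
    return torch.cuda.get_device_capability()[0] >= major


# Hypothetical test illustrating the skip pattern used by the updated tests.
def test_needs_tensor_core_dot():
    if not has_min_cc(7):
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    ...  # body that exercises tl.dot()
```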

Importantly, Triton passes all the other tests on P100. Although this support
is not official, it is valuable for projects such as PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) as its
backend (https://github.com/pytorch/torchdynamo/issues/1591); a sketch of such
a `tl.dot()`-free kernel follows.
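
For illustration, a standard tutorial-style element-wise addition kernel (not
code from this PR) that avoids `tl.dot()` and is therefore the kind of kernel
that works on P100:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.rand(4096, device='cuda')
y = torch.rand(4096, device='cuda')
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, x + y)
```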

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```
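
If a later codegen change alters shared memory usage, the affected test groups
can be re-checked quickly with pytest's `-k` filter. A small sketch (this
selects every `reduce2d`/`permute` parametrization, not only the three failing
ones):

```python
import pytest

# Re-run only the test groups that contain the shared memory failures above.
pytest.main([
    "test/unit/language/test_core.py",
    "-k", "test_reduce2d or test_permute",
])
```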

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
  GPU: NVIDIA Tesla P100 x 1
  CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, it cost about $1.2 per hour.
- US instances were full when I tried; I used Asia or Australia.
- A paid account is needed (GPUs are not covered by the free trial).
- A quota request is needed for any GPU instance (none are allowed by default); approval took about an hour.
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Set up the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```
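
Before running the tests, it may be worth confirming from Python that the
visible device really is a Pascal (compute capability 6.0) GPU; a small
sanity-check sketch, not part of this PR:

```python
import torch

props = torch.cuda.get_device_properties(0)
print(props.name)                  # e.g. "Tesla P100-PCIE-16GB"
print((props.major, props.minor))  # (6, 0) on Pascal / P100
```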

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>

Commit 77bc5187b5 (parent f16138d447), authored by Shintaro Iwasaki on
2022-11-03 00:11:52 -07:00, committed by GitHub. 4 changed files with 24
additions and 0 deletions.


```diff
@@ -149,6 +149,13 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
   // ir.print(std::cout);
   isel.visit(ir, *llvm);
   shared_static = allocation.allocated_size();
+  if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
+    // sm < 70 (Pascal) has little shared memory resource.
+    // Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
+    if (shared_static >= 65536) {
+      throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
+    }
+  }
   if (isel.get_extern_lib_map().size() > 0) {
     // If there's any extern lib calls,
```


```diff
@@ -605,6 +605,10 @@ def test_tuples():
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
 def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        if dtype_x_str == 'float16':
+            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
     n_programs = 5
     # triton kernel
@@ -1042,6 +1046,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
                          if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80:
         if dtype == 'int8':
             pytest.skip("Only test int8 on devices with sm >= 80")
@@ -1227,6 +1233,10 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
     M = 32
```


```diff
@@ -2,6 +2,7 @@ import pytest
 import torch
 import triton
+import triton._C.libtriton.triton as _triton
 @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -125,6 +126,10 @@ def test_attention_fwd_bwd(
     batch_size=2,
     n_heads=2,
 ):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
     qkvs = [
```


```diff
@@ -68,6 +68,8 @@ import triton._C.libtriton.triton as _triton
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80 and DTYPE == "bfloat16":
         pytest.skip("Only test bfloat16 on devices with sm >= 80")
     if DTYPE == "bfloat16" and SPLIT_K != 1:
```