import pytest
import torch

import triton


@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
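    # MODE selects which operand of C = A @ B is block-sparse:
    #   "sdd": C is sparse, A and B are dense
    #   "dsd": A is sparse, B and C are dense
    #   "dds": B is sparse, A and C are dense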
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
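    # note: `layout` is only assigned further down; the lambdas above close over the
    # name and resolve it lazily when they are called, after `layout` exists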
    # input / output shapes
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    # logical shape of the sparse operand, used to size the block layout
    shape = {
        "sdd": (M, N),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
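    # random 0/1 block layout over the sparse operand's block grid; zeroing one
    # block-row and one block-column of head 1 also exercises fully-empty rows/columns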
|
2021-02-08 12:16:41 -08:00
|
|
|
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
2022-02-06 18:00:45 -08:00
|
|
|
layout[1, 2, :] = 0
|
|
|
|
layout[1, :, 1] = 0
|
|
|
|
    # create data
    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
    # compute [torch]
    dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref.backward(dc_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
    da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
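    # the dense reference results/gradients were sparsified above so they can be
    # compared block-for-block against Triton's compact block-sparse outputs below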
    # triton result
    dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
    triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
    da_tri = a_tri.grad
    db_tri = b_tri.grad
    # compare
    triton.testing.assert_almost_equal(c_ref, c_tri)
    triton.testing.assert_almost_equal(da_ref, da_tri)
    triton.testing.assert_almost_equal(db_ref, db_tri)


# (BLOCK, WIDTH) pairs for the softmax test
configs = [
    (16, 256),
    (32, 576),
    (64, 1871),
    (128, 2511),
]


@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 3, WIDTH, WIDTH
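    # note: this overrides the Z and H keyword defaults from the signature
    # (H becomes 3) and ties both matrix dimensions to WIDTH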
    # initialize layout
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    if is_dense:
        layout[:] = 1
    else:
        # empty out one block-row and one block-column of head 1
        layout[1, 2, :] = 0
        layout[1, :, 1] = 0
    # initialize data
    a_shape = (Z, H, M, N)
    a_ref, a_tri = triton.testing.make_pair(a_shape)
    dout_ref, dout_tri = triton.testing.make_pair(a_shape)
    # compute [torch]
    a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
    a_ref.retain_grad()
    at_mask = torch.ones((M, N), device="cuda")
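    # dense reference masking: the 0/1 mask (lower-triangular when is_causal) is
    # broadcast over the batch and zero positions are set to -inf before the softmax;
    # the name M is reused for the broadcast mask, shadowing the row count above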
    if is_causal:
        at_mask = torch.tril(at_mask)
    M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
    a_ref[M == 0] = float("-inf")
    out_ref = torch.softmax(a_ref * scale, -1)
    out_ref.backward(dout_ref)
    out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
    da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
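    # the Triton path works on the compact representation: inputs are reduced to
    # their nonzero blocks and the block-sparse softmax runs on that layout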
    # compute [triton]
    a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
    a_tri.retain_grad()
    dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
    out_tri = op(a_tri, scale=scale, is_causal=is_causal)
    out_tri.backward(dout_tri)
    da_tri = a_tri.grad
    # compare
    triton.testing.assert_almost_equal(out_tri, out_ref)
    triton.testing.assert_almost_equal(da_tri, da_ref)


@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_attention_fwd_bwd(
    block,
    dtype,
    input_scale=1.0,
    scale=1 / 8.0,
    n_ctx=256,
    batch_size=2,
    n_heads=2,
):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")

    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
    ]

    # Triton:
    n_blocks = n_ctx // block
    # lower-triangular (causal) block layout
    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
    query, key, value = [x.clone() for x in qkvs]
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
    # ad hoc loss
    loss = (attn_out ** 2).mean()
    loss.backward()
    grads = [query.grad, key.grad, value.grad]

    # Torch version:
    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
    attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
    attn_mask = torch.tril(attn_mask, diagonal=0)
    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
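    # the 0/1 lower-triangular mask becomes an additive bias: 0 where attention is
    # allowed, -1e6 where it is masked, mimicking the causal block-sparse softmax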
    torch_q.retain_grad()
    torch_k.retain_grad()
    torch_v.retain_grad()
    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
    scores = scores + attn_mask
    probs = torch.softmax(scores, dim=-1)
    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
    # ad hoc loss
    torch_loss = (torch_attn_out ** 2).mean()
    torch_loss.backward()
    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]

    # comparison
    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
    triton.testing.assert_almost_equal(loss, torch_loss)
    for g1, g2 in zip(grads, torch_grads):
        triton.testing.assert_almost_equal(g1, g2)


# helper (not collected by pytest): block-sparse causal attention built from Triton ops
def triton_attention(
    layout,
    block: int,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
    sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
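    # block-sparse attention pipeline: "sdd" computes only the Q @ K^T blocks kept by
    # the causal layout, the block-sparse softmax normalizes them, and "dsd"
    # multiplies the sparse probabilities with the dense V to produce a dense output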

    w = sparse_dot_sdd_nt(query, key)
    w = sparse_softmax(w, scale=scale, is_causal=True)
    a = sparse_dot_dsd_nn(w, value)
    return a