[DOCS] fix tutorials for v2.0 (#422)
- Fix meta-parameter usage on tutorials.
- Install tutorial dependencies on CI.
- Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies.
- Make some performance tests deterministic.
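For context on the meta-parameter fix: the v2.0 API drops the implicit `**meta` dictionary, so compile-time constants are declared in the kernel signature as `tl.constexpr` and passed by keyword at launch. A minimal sketch of the new style (the `scale_kernel` here is invented for illustration, not code from this commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, y_ptr, n_elements, scale,
                     BLOCK_SIZE: tl.constexpr):  # declared in the signature, not read from **meta
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x * scale, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.empty_like(x)
    # the grid callable still receives the meta-parameters as a dict
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
    scale_kernel[grid](x, y, x.numel(), 2.0, BLOCK_SIZE=1024)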
commit 9801aa7b56 (parent 8bf551ae7a)

.github/workflows/documentation.yml
@@ -25,7 +25,7 @@ jobs:
        run: |
          alias python='python3'
          cd python
-         pip3 install -e .
+         pip3 install -e '.[tutorials]'

      - name: Build docs
        run: |
@@ -39,4 +39,4 @@ jobs:
          eval `ssh-agent -s`
          DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
          git remote set-url origin git@github.com:openai/triton.git
          git push
.github/workflows/integration-tests.yml
@@ -28,7 +28,7 @@ jobs:
        run: |
          alias python='python3'
          cd python
-         pip3 install -e .
+         pip3 install -e '.[tests]'

      - name: Unit tests
        run: |
@@ -44,4 +44,3 @@ jobs:
          pytest -vs .
          sudo nvidia-smi -i 0 -rgc
          sudo nvidia-smi -i 0 -rmc
-
docs/getting-started/installation.rst
@@ -44,7 +44,7 @@ You can then test your installation by running the unit tests:

 .. code-block:: bash

-      pip install -r requirements-test.txt
+      pip install -e '.[tests]'
       pytest -vs test/unit/

 and the benchmarks
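Note that `pip install -e '.[tests]'` has to run from the directory that contains setup.py; in this repository that is the python/ subdirectory, which is why the CI snippets above `cd python` first.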
python/requirements-test.txt (deleted)
@@ -1,3 +0,0 @@
-numpy
-pytest
-scipy >= 1.7.1
python/setup.py
@@ -126,7 +126,11 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["cmake", "torch", "filelock"],
+    install_requires=[
+        "cmake",
+        "filelock",
+        "torch",
+    ],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -142,4 +146,16 @@ setup(
         "License :: OSI Approved :: MIT License",
         "Programming Language :: Python :: 3.6",
     ],
+    extras_require={
+        "tests": [
+            "numpy",
+            "pytest",
+            "scipy>=1.7.1",
+        ],
+        "tutorials": [
+            "matplotlib",
+            "pandas",
+            "tabulate",
+        ],
+    },
 )
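With the extras defined, one editable install pulls in an optional dependency group: `pip install -e '.[tests]'` adds numpy, pytest, and scipy on top of the base requirements, `pip install -e '.[tutorials]'` adds matplotlib, pandas, and tabulate, and `pip install -e '.[tests,tutorials]'` combines both.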
python/test/regression/test_performance.py
@@ -54,6 +54,7 @@ matmul_data = {

 @pytest.mark.parametrize('M, N, K', matmul_data.keys())
 def test_matmul(M, N, K):
+    torch.manual_seed(0)
     ref_gpu_util = matmul_data[(M, N, K)]['v100']
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
     ref_sm_clock = 1350
@@ -99,6 +100,7 @@ elementwise_data = {

 @pytest.mark.parametrize('N', elementwise_data.keys())
 def test_elementwise(N):
+    torch.manual_seed(0)
     ref_gpu_util = elementwise_data[N]['v100']
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
     ref_mem_clock = 877
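Seeding the generator at the top of each test is what makes these benchmarks deterministic: the inputs drawn afterwards are bit-identical on every run, so measured utilization can be compared against the recorded V100 reference numbers. A minimal sketch of the pattern (the `bench_inputs` helper is hypothetical, not part of this commit):

    import torch

    def bench_inputs(M, N, K):
        # fixing the seed before drawing inputs makes every run use the
        # same data, so timing differences reflect the kernel, not the input
        torch.manual_seed(0)
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        return a, b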
python/tutorials/02-fused-softmax.py
@@ -133,7 +133,7 @@ torch.manual_seed(0)
 x = torch.randn(1823, 781, device='cuda')
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
-print(torch.allclose(y_triton, y_torch))
+assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

 # %%
 # As expected, the results are identical.
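Asserting on `torch.allclose` instead of printing its result makes the tutorial fail loudly when the Triton and PyTorch outputs diverge, so the CI job that now runs the tutorials catches regressions instead of silently rendering `False` into the gallery output.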
python/tutorials/03-matrix-multiplication.py
@@ -237,9 +237,9 @@ def matmul_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
-    # while the accumulator is still in FP32 !
-    if meta['ACTIVATION']:
-        accumulator = meta['ACTIVATION'](accumulator)
+    # while the accumulator is still in FP32!
+    if ACTIVATION:
+        accumulator = ACTIVATION(accumulator)
     c = accumulator.to(tl.float16)

 # -----------------------------------------------------------
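With `ACTIVATION` promoted to an explicit compile-time parameter, a fused activation is supplied by keyword at launch rather than dug out of `meta`. A standalone sketch of the same idea on a trivial elementwise kernel (all names here are invented, and it assumes, as the tutorial does, that a `@triton.jit` function can be passed as a meta-parameter):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def leaky_relu(x):
        return tl.where(x >= 0, x, 0.01 * x)

    @triton.jit
    def apply_kernel(x_ptr, y_ptr, n_elements,
                     ACTIVATION: tl.constexpr, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        if ACTIVATION:  # resolved at compile time; the dead branch is eliminated
            x = ACTIVATION(x)
        tl.store(y_ptr + offsets, x, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    apply_kernel[grid](x, y, x.numel(), ACTIVATION=leaky_relu, BLOCK_SIZE=1024)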
python/tutorials/04-low-memory-dropout.py
@@ -42,9 +42,8 @@ def _dropout(
     output_ptr,  # pointer to the output
     n_elements,  # number of elements in the `x` tensor
     p,  # probability that an element of `x` is changed to zero
-    **meta,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -108,10 +107,9 @@ def _seeded_dropout(
     n_elements,
     p,
     seed,
-    **meta,
+    BLOCK_SIZE: tl.constexpr,
 ):
     # compute memory offsets of elements handled by this instance
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
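On the launch side, callers of these kernels now pass `BLOCK_SIZE=...` as a keyword argument instead of supplying it through an implicit `meta` dictionary; the grid callable still receives all meta-parameters as a dict, so a grid like `lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)` keeps working unchanged.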
python/tutorials/05-layer-norm.py
@@ -8,11 +8,19 @@ import torch
 import triton
 import triton.language as tl

+try:
+    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
+    # should not be added to extras_require in setup.py.
+    import apex
+    HAS_APEX = True
+except ModuleNotFoundError:
+    HAS_APEX = False
+

 # Forward Pass
 @triton.jit
-def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
-    BLOCK_SIZE = META['BLOCK_SIZE']
+def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
+                          BLOCK_SIZE: tl.constexpr):
     # position of elements processed by this program
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK_SIZE)
@@ -42,11 +50,8 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):

 # Backward pass (DX + partial DW + partial DB)
 @triton.jit
-def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
-                             stride, N, eps,
-                             **META):
-    GROUP_SIZE_M = META['GROUP_SIZE_M']
-    BLOCK_SIZE_N = META['BLOCK_SIZE_N']
+def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
+                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
     # position of elements processed by this program
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK_SIZE_N)
@@ -102,15 +107,14 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,


 @triton.jit
-def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
+def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
+                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
     pid = tl.program_id(0)
-    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
-    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
+        rows = i + tl.arange(0, BLOCK_SIZE_M)
         mask = (rows[:, None] < M) & (cols[None, :] < N)
         offs = rows[:, None] * N + cols[None, :]
         dw += tl.load(DW + offs, mask=mask, other=0.)
@@ -216,8 +220,8 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
         x_names=['N'],
         x_vals=[512 * i for i in range(2, 32)],
         line_arg='provider',
-        line_vals=['triton', 'torch', 'apex'],
-        line_names=['Triton', 'Torch', 'Apex'],
+        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
+        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
         styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
         ylabel='GB/s',
         plot_name='layer-norm-backward',
@@ -239,7 +243,6 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
     if provider == 'torch':
         y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
     if provider == 'apex':
-        import apex
         apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
         y_fwd = lambda: apex_layer_norm(x)
     # forward pass
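Importing apex once at module scope, guarded by try/except, serves two purposes: the benchmark no longer crashes at the `import apex` that used to sit inside the bench function, and the `HAS_APEX` flag trims apex out of the plotted providers, so the tutorial runs end-to-end on machines without it.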
python/tutorials/README.rst
@@ -1,4 +1,11 @@
 Tutorials
 ==================

 Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
+
+To install the dependencies for the tutorials:
+
+.. code-block:: bash
+
+    cd triton
+    pip install -e './python[tutorials]'
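A quick smoke test that the tutorial extras resolved (hypothetical, not part of the commit):

    # each of these packages is listed under the "tutorials" extra in setup.py
    import matplotlib
    import pandas
    import tabulate

    print("tutorial dependencies OK")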