[DOCS] fix tutorials for v2.0 (#422)
- Fix meta-parameter usage in tutorials.
- Install tutorial dependencies on CI.
- Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies.
- Make some performance tests deterministic.
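For context, a minimal sketch of what the `extras_require` switch looks like in `setup.py`. The package lists below are illustrative placeholders, not the actual lists from this commit:

    from setuptools import setup

    setup(
        name="triton",
        # ... other arguments unchanged ...
        extras_require={
            # replaces requirements-test.txt
            "tests": ["pytest", "scipy"],
            # installed via `pip install -e './python[tutorials]'`
            "tutorials": ["matplotlib", "pandas", "tabulate"],
        },
    )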
parent 8bf551ae7a
commit 9801aa7b56
@@ -133,7 +133,7 @@ torch.manual_seed(0)
 x = torch.randn(1823, 781, device='cuda')
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
-print(torch.allclose(y_triton, y_torch))
+assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
 
 # %%
 # As expected, the results are identical.
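An assert (rather than a print) makes a numerical regression fail CI loudly, and the tuple after the condition becomes the failure message, so the logs show both tensors on a mismatch. Note that `torch.allclose` defaults to `rtol=1e-05, atol=1e-08`; when comparing half-precision kernels it is common to loosen these, along the lines of this sketch (tolerance values illustrative, not from the commit):

    import torch

    def close_enough(a: torch.Tensor, b: torch.Tensor) -> bool:
        # fp16 accumulates more rounding error than fp32, so relax the
        # default tolerances; the values here are illustrative
        return torch.allclose(a.float(), b.float(), rtol=1e-3, atol=1e-3)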
@@ -237,9 +237,9 @@ def matmul_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
-    # while the accumulator is still in FP32 !
-    if meta['ACTIVATION']:
-        accumulator = meta['ACTIVATION'](accumulator)
+    # while the accumulator is still in FP32!
+    if ACTIVATION:
+        accumulator = ACTIVATION(accumulator)
     c = accumulator.to(tl.float16)
 
# -----------------------------------------------------------
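With `**meta` gone in v2.0, the activation arrives as a `tl.constexpr` kernel argument, so the `if ACTIVATION:` branch is resolved at compile time and the call is inlined into the specialized kernel. A minimal standalone sketch of the same pattern (the kernel and its launch here are hypothetical, not the tutorial's matmul):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def leaky_relu(x):
        return tl.where(x >= 0, x, 0.01 * x)

    @triton.jit
    def activate_kernel(x_ptr, y_ptr, n_elements,
                        ACTIVATION: tl.constexpr, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        if ACTIVATION:  # compile-time branch: dead code is eliminated
            x = ACTIVATION(x)
        tl.store(y_ptr + offsets, x, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
    activate_kernel[grid](x, y, x.numel(), ACTIVATION=leaky_relu, BLOCK_SIZE=1024)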
@@ -42,9 +42,8 @@ def _dropout(
     output_ptr,  # pointer to the output
     n_elements,  # number of elements in the `x` tensor
     p,  # probability that an element of `x` is changed to zero
-    **meta,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -108,10 +107,9 @@ def _seeded_dropout(
     n_elements,
     p,
     seed,
-    **meta,
+    BLOCK_SIZE: tl.constexpr,
 ):
     # compute memory offsets of elements handled by this instance
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
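For both dropout kernels only the signature changes; the launch keeps its shape. `tl.constexpr` arguments are supplied as keyword arguments, and the grid lambda still reads them from the meta dict, exactly as before. A sketch of the wrapper, mirroring the tutorial's launch code (the `BLOCK_SIZE` value is assumed):

    def seeded_dropout(x, p, seed):
        output = torch.empty_like(x)
        assert x.is_contiguous()
        n_elements = x.numel()
        # the grid lambda receives the constexpr meta-parameters as before
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
        return output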
@@ -8,11 +8,19 @@ import torch
 import triton
 import triton.language as tl
 
+try:
+    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
+    # should not be added to extras_require in setup.py.
+    import apex
+    HAS_APEX = True
+except ModuleNotFoundError:
+    HAS_APEX = False
+
 
 # Forward Pass
 @triton.jit
-def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
-    BLOCK_SIZE = META['BLOCK_SIZE']
+def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
+                          BLOCK_SIZE: tl.constexpr):
     # position of elements processed by this program
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK_SIZE)
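On the calling side, the wrapper now computes the block size and forwards it by keyword instead of packing it into `**META`. A hedged sketch, assuming the tutorial's one-program-per-row layout (the exact size heuristic in the tutorial may differ):

    def layer_norm_fwd(x, weight, bias, eps=1e-5):
        M, N = x.shape
        y = torch.empty_like(x)
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
        var = torch.empty((M,), dtype=torch.float32, device=x.device)
        # one program per row, so BLOCK_SIZE must cover all N columns
        BLOCK_SIZE = triton.next_power_of_2(N)
        _layer_norm_fwd_fused[(M,)](x, y, weight, bias, mean, var,
                                    x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE)
        return y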
@@ -42,11 +50,8 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
 
 # Backward pass (DX + partial DW + partial DB)
 @triton.jit
-def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
-                             stride, N, eps,
-                             **META):
-    GROUP_SIZE_M = META['GROUP_SIZE_M']
-    BLOCK_SIZE_N = META['BLOCK_SIZE_N']
+def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
+                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
     # position of elements processed by this program
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK_SIZE_N)
@@ -102,15 +107,14 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
 
 
 @triton.jit
-def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
+def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
+                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
     pid = tl.program_id(0)
-    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
-    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
+        rows = i + tl.arange(0, BLOCK_SIZE_M)
         mask = (rows[:, None] < M) & (cols[None, :] < N)
         offs = rows[:, None] * N + cols[None, :]
         dw += tl.load(DW + offs, mask=mask, other=0.)
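This kernel reduces an (M, N) buffer of partial weight/bias gradients: each program owns one block of columns and loops over all row blocks. A hedged launch sketch, assuming the tutorial's imports and kernel; buffer shapes and block sizes are illustrative:

    # `partial_dw`/`partial_db` hold per-row-group gradients written by the
    # dx kernel above; here M is the number of row groups, not batch rows
    M, N = 64, 4096
    partial_dw = torch.zeros((M, N), dtype=torch.float32, device='cuda')
    partial_db = torch.zeros((M, N), dtype=torch.float32, device='cuda')
    final_dw = torch.empty((N,), dtype=torch.float32, device='cuda')
    final_db = torch.empty((N,), dtype=torch.float32, device='cuda')
    # one program per column block; each reduces every row block
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']),)
    _layer_norm_bwd_dwdb[grid](partial_dw, partial_db, final_dw, final_db,
                               M, N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128)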
@@ -216,8 +220,8 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
         x_names=['N'],
         x_vals=[512 * i for i in range(2, 32)],
         line_arg='provider',
-        line_vals=['triton', 'torch', 'apex'],
-        line_names=['Triton', 'Torch', 'Apex'],
+        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
+        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
         styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
         ylabel='GB/s',
         plot_name='layer-norm-backward',
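These keyword arguments configure a `triton.testing.Benchmark` consumed by the `@triton.testing.perf_report` decorator; gating `line_vals`/`line_names` on `HAS_APEX` keeps the benchmark runnable when apex is not installed. A hedged sketch of the surrounding structure (the `args` dict is assumed, not shown in this hunk):

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[512 * i for i in range(2, 32)],
            line_arg='provider',
            line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
            line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
            styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
            ylabel='GB/s',
            plot_name='layer-norm-backward',
            args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},  # assumed
        )
    )
    def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
        ...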
@@ -239,7 +243,6 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
     if provider == 'torch':
         y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
     if provider == 'apex':
-        import apex
         apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
         y_fwd = lambda: apex_layer_norm(x)
     # forward pass
@@ -1,4 +1,11 @@
 Tutorials
 ==================
 
 Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
+
+To install the dependencies for the tutorials:
+
+.. code-block:: bash
+
+    cd triton
+    pip install -e './python[tutorials]'
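The quotes around `./python[tutorials]` keep shells such as zsh from expanding the square brackets as a glob, and `-e` installs the package in editable mode, so local changes take effect without reinstalling.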