[GH-PAGES] Updated website
@@ -1,4 +1,4 @@
|
||||
# Sphinx build info version 1
|
||||
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
|
||||
config: 255e5b7e27427f0ab6fda308ee6aef63
|
||||
config: 8d52c5eda79abb41e578ed40b306519c
|
||||
tags: 645f666f9bcd5a90fca523b33c5a78b7
|
||||
|
@@ -0,0 +1,97 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Libdevice function\nTriton can invoke a custom function from an external library.\nIn this example, we will use the `libdevice` library to apply `asin` on a tensor.\nPlease refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.\n\nIn `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.\nFor example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.\nUsing triton, you can simply call `tl.libdevice.asinf`.\ntriton automatically selects the correct underlying device function to invoke based on input and output types.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## asin Kernel\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x = tl.libdevice.asin(x)\n tl.store(y_ptr + offsets, x, mask=mask)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using the default libdevice library path\nWe can use the default libdevice library path encoded in `triton/language/libdevice.py`\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Customize the libdevice library path\nWe can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Libdevice function
|
||||
==================
|
||||
Triton can invoke a custom function from an external library.
|
||||
In this example, we will use the `libdevice` library to apply `asin` to a tensor.
|
||||
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
|
||||
|
||||
In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
||||
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
||||
Using Triton, you can simply call `tl.libdevice.asin`.
|
||||
Triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# asin Kernel
|
||||
# --------------------------
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def asin_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
x = tl.libdevice.asin(x)
|
||||
tl.store(y_ptr + offsets, x, mask=mask)
|
||||
|
||||
# %%
|
||||
# Using the default libdevice library path
|
||||
# ----------------------------------------
|
||||
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`.
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
output_triton = torch.zeros(size, device='cuda')
|
||||
output_torch = torch.asin(x)
|
||||
assert x.is_cuda and output_triton.is_cuda
|
||||
n_elements = output_torch.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
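
# %%
# Data type dispatch
# --------------------------
# A minimal sketch, not part of the original tutorial: because `tl.libdevice.asin`
# dispatches on the element type, the same `asin_kernel` can be reused for `float64`
# input, in which case Triton is expected to pick `__nv_asin` rather than `__nv_asinf`.
# This assumes fp64 loads/stores are supported on the target GPU.

x_fp64 = x.to(torch.float64)
output_fp64 = torch.empty_like(x_fp64)
asin_kernel[grid](x_fp64, output_fp64, n_elements, BLOCK_SIZE=1024)
print(
    f'The maximum fp64 difference between torch and triton is '
    f'{torch.max(torch.abs(torch.asin(x_fp64) - output_fp64))}'
)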
|
||||
|
||||
# %%
|
||||
# Customize the libdevice library path
|
||||
# ------------------------------------
|
||||
# We can also customize the libdevice library path by passing the path to a `libdevice` bitcode file to the `asin` kernel via the `extern_libs` argument.
|
||||
|
||||
output_triton = torch.empty_like(x)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
denom = tl.load(L + off_m).to(tl.float32)
|
||||
# compute
|
||||
do = do / denom[:, None]
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_qz + off_h * stride_qh
|
||||
V += off_z * stride_qz + off_h * stride_qh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
lo = start_n * BLOCK_M
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
m_ptrs = M + off_hz * N_CTX
|
||||
# initialize dv and dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-1]
|
||||
assert Lq == Lk
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=64, num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = 64
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
attention = _attention.apply
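
# A small usage sketch (added; shapes are illustrative and mirror `test_op` below).
# The fused op expects (Z, H, N_CTX, D_HEAD) fp16 CUDA tensors with D_HEAD == 64
# (BLOCK_DMODEL is hard-coded above) and N_CTX a multiple of BLOCK (128).
def example_usage():
    q = torch.randn((2, 4, 1024, 64), dtype=torch.float16, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    out = attention(q, k, v, 0.5)            # forward pass (_fwd_kernel)
    out.backward(torch.randn_like(out))      # backward pass (_bwd_preprocess + _bwd_kernel)
    return out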
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# only works on A100 at the moment
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
UNROLL: tl.constexpr = 4
|
||||
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
|
||||
for j in range(UNROLL):
|
||||
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
tl.store(DW + cols, sum_dw, mask=cols < N)
|
||||
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
)
|
||||
# accumulate partial sums in separate kernel
|
||||
if N > 10240:
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_M = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
BLOCK_SIZE_N = 16
|
||||
BLOCK_SIZE_M = 16
|
||||
num_warps = 8
|
||||
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||
_layer_norm_bwd_dwdb[grid](
|
||||
a, dout,
|
||||
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
|
||||
dbias,
|
||||
M,
|
||||
N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
num_warps=num_warps
|
||||
)
|
||||
return (da, None, dweight, dbias, None, None,
|
||||
None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None, None, None,
|
||||
None, None, None)
|
||||
return (da, None, dweight, dbias, None)
|
||||
|
||||
|
||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 37 KiB After Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 23 KiB After Width: | Height: | Size: 23 KiB |
Before Width: | Height: | Size: 59 KiB After Width: | Height: | Size: 59 KiB |
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 36 KiB After Width: | Height: | Size: 35 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 22 KiB |
BIN
master/_images/sphx_glr_06-fused-attention_thumb.png
Normal file
After Width: | Height: | Size: 26 KiB |
BIN
master/_images/sphx_glr_07-libdevice-function_thumb.png
Normal file
After Width: | Height: | Size: 26 KiB |
@@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
3 32768.0 76.800002 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
6 262144.0 341.333321 341.333321
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
@@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 34.829 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 50.020 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
@@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
|
||||
softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 546.133347 512.000001 188.321838
|
||||
1 384.0 614.400016 585.142862 153.600004
|
||||
2 512.0 655.360017 585.142849 154.566038
|
||||
0 256.0 546.133347 546.133347 186.181817
|
||||
1 384.0 614.400016 585.142862 151.703707
|
||||
2 512.0 655.360017 606.814814 154.566038
|
||||
3 640.0 706.206879 640.000002 160.000000
|
||||
4 768.0 722.823517 664.216187 162.754967
|
||||
4 768.0 722.823517 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 406.179533 198.936606
|
||||
94 12288.0 812.429770 415.222812 199.298541
|
||||
95 12416.0 812.498981 412.149375 198.954424
|
||||
96 12544.0 812.566838 412.758863 199.209928
|
||||
97 12672.0 811.007961 412.097543 199.264875
|
||||
94 12288.0 812.429770 415.222812 199.096718
|
||||
95 12416.0 812.498981 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.012395
|
||||
97 12672.0 811.007961 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
@@ -306,7 +306,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 3 minutes 18.076 seconds)
|
||||
**Total running time of the script:** ( 3 minutes 32.089 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
@@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
|
||||
matmul-performance:
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 256.0 2.730667 ... 2.978909 2.978909
|
||||
0 256.0 2.730667 ... 3.276800 2.978909
|
||||
1 384.0 7.372800 ... 7.899428 7.899428
|
||||
2 512.0 14.563555 ... 15.420235 15.420235
|
||||
2 512.0 14.563555 ... 16.384000 15.420235
|
||||
3 640.0 22.260869 ... 24.380953 24.380953
|
||||
4 768.0 32.768000 ... 35.389441 34.028308
|
||||
5 896.0 37.971025 ... 40.140799 39.025776
|
||||
5 896.0 39.025776 ... 40.140799 39.025776
|
||||
6 1024.0 49.932191 ... 53.773130 52.428801
|
||||
7 1152.0 45.242181 ... 48.161033 47.396572
|
||||
7 1152.0 45.242181 ... 47.396572 47.396572
|
||||
8 1280.0 51.200001 ... 57.690139 57.690139
|
||||
9 1408.0 64.138541 ... 68.147202 65.684049
|
||||
10 1536.0 79.526831 ... 81.355034 78.643199
|
||||
11 1664.0 63.372618 ... 63.372618 62.492442
|
||||
9 1408.0 64.138541 ... 68.147202 66.485074
|
||||
10 1536.0 80.430545 ... 80.430545 78.643199
|
||||
11 1664.0 62.929456 ... 63.372618 62.492442
|
||||
12 1792.0 72.983276 ... 72.983276 59.154861
|
||||
13 1920.0 68.776119 ... 71.626943 70.892307
|
||||
14 2048.0 73.584279 ... 78.033565 76.959706
|
||||
15 2176.0 83.155572 ... 87.494120 86.367588
|
||||
16 2304.0 68.446623 ... 78.064941 77.057651
|
||||
17 2432.0 71.305746 ... 86.179335 85.393507
|
||||
18 2560.0 77.833728 ... 82.956960 81.715711
|
||||
19 2688.0 83.737433 ... 91.185232 89.464755
|
||||
20 2816.0 82.446516 ... 84.523664 83.712490
|
||||
21 2944.0 81.967162 ... 83.758038 82.373605
|
||||
22 3072.0 82.420822 ... 88.750943 86.579673
|
||||
23 3200.0 81.528664 ... 91.233074 95.665176
|
||||
24 3328.0 83.516586 ... 85.908470 83.323259
|
||||
25 3456.0 81.435930 ... 92.138932 90.180725
|
||||
26 3584.0 83.954614 ... 91.189190 95.858629
|
||||
27 3712.0 85.822459 ... 83.806497 87.783251
|
||||
28 3840.0 80.901241 ... 89.259080 89.548180
|
||||
29 3968.0 87.913500 ... 92.829164 84.096442
|
||||
30 4096.0 93.825748 ... 89.299883 90.139506
|
||||
13 1920.0 69.120002 ... 71.257735 71.257735
|
||||
14 2048.0 73.584279 ... 78.398206 77.314362
|
||||
15 2176.0 83.155572 ... 87.494120 85.998493
|
||||
16 2304.0 68.446623 ... 78.320893 77.558029
|
||||
17 2432.0 71.305746 ... 86.711310 75.421383
|
||||
18 2560.0 77.833728 ... 82.747477 81.715711
|
||||
19 2688.0 83.552988 ... 90.532356 89.464755
|
||||
20 2816.0 84.197315 ... 84.035084 84.035084
|
||||
21 2944.0 82.784108 ... 83.969728 83.060049
|
||||
22 3072.0 81.825298 ... 89.593522 88.473602
|
||||
23 3200.0 84.768213 ... 96.096095 95.808380
|
||||
24 3328.0 83.226931 ... 85.908470 84.596116
|
||||
25 3456.0 81.766291 ... 91.824110 91.097818
|
||||
26 3584.0 87.466332 ... 91.194972 94.847460
|
||||
27 3712.0 85.822459 ... 87.246590 87.860458
|
||||
28 3840.0 81.859361 ... 87.011801 90.168771
|
||||
29 3968.0 89.921841 ... 91.954739 85.271796
|
||||
30 4096.0 93.596744 ... 88.243079 90.382307
|
||||
|
||||
[31 rows x 5 columns]
|
||||
|
||||
@@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 5 minutes 52.578 seconds)
|
||||
**Total running time of the script:** ( 7 minutes 13.827 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
@@ -240,7 +240,7 @@ References
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.476 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 0.279 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:
|
||||
|
@@ -21,7 +21,7 @@
|
||||
Layer Normalization
|
||||
====================
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 5-312
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 5-316
|
||||
|
||||
|
||||
|
||||
@@ -40,34 +40,34 @@ Layer Normalization
|
||||
N Triton Torch Apex
|
||||
0 1024.0 585.142849 277.694907 468.114273
|
||||
1 1536.0 630.153868 323.368435 511.999982
|
||||
2 2048.0 682.666643 334.367358 520.126988
|
||||
3 2560.0 694.237267 365.714281 518.481028
|
||||
4 3072.0 712.347810 378.092307 501.551037
|
||||
5 3584.0 725.873439 384.859062 458.751978
|
||||
6 4096.0 728.177767 381.023256 458.293714
|
||||
7 4608.0 670.254540 396.387087 426.173427
|
||||
8 5120.0 694.237267 397.669909 426.666652
|
||||
9 5632.0 704.000002 396.969169 413.357796
|
||||
10 6144.0 702.171410 402.885254 411.313806
|
||||
2 2048.0 668.734716 337.814445 528.516136
|
||||
3 2560.0 694.237267 362.477870 512.000013
|
||||
4 3072.0 712.347810 375.206126 501.551037
|
||||
5 3584.0 725.873439 384.859062 451.527536
|
||||
6 4096.0 728.177767 381.023256 455.111095
|
||||
7 4608.0 670.254540 396.387087 421.302872
|
||||
8 5120.0 688.403381 395.748783 422.268057
|
||||
9 5632.0 698.542675 396.969169 409.599997
|
||||
10 6144.0 702.171410 402.885254 409.600010
|
||||
11 6656.0 700.631610 400.360920 400.360920
|
||||
12 7168.0 695.078767 396.844306 388.772874
|
||||
13 7680.0 682.666656 393.846167 387.634072
|
||||
14 8192.0 642.509816 393.609605 372.363633
|
||||
15 8704.0 627.315309 389.005597 380.502740
|
||||
16 9216.0 606.814809 407.337026 383.999986
|
||||
17 9728.0 589.575753 409.599987 383.369452
|
||||
18 10240.0 566.920437 408.578556 382.803739
|
||||
19 10752.0 549.623009 411.559798 381.445676
|
||||
20 11264.0 536.380957 406.826188 373.134567
|
||||
21 11776.0 523.377770 410.492372 377.587162
|
||||
22 12288.0 517.389457 414.784810 383.251457
|
||||
23 12800.0 505.679014 410.420828 376.470582
|
||||
24 13312.0 494.180982 405.699062 376.976995
|
||||
25 13824.0 482.934503 411.888257 379.389355
|
||||
26 14336.0 471.967074 406.695045 374.185964
|
||||
27 14848.0 461.297068 408.192434 375.304904
|
||||
28 15360.0 454.269882 406.214870 378.092307
|
||||
29 15872.0 447.887117 407.627589 376.225175
|
||||
12 7168.0 678.627194 386.154893 384.859062
|
||||
13 7680.0 682.666656 391.337574 386.415087
|
||||
14 8192.0 645.674867 390.095241 376.643677
|
||||
15 8704.0 624.502255 390.095225 379.465939
|
||||
16 9216.0 604.327881 405.098894 383.002605
|
||||
17 9728.0 585.142883 409.599987 382.427505
|
||||
18 10240.0 564.965524 409.600010 382.803739
|
||||
19 10752.0 546.133312 410.577576 380.601764
|
||||
20 11264.0 531.634232 395.228063 370.069806
|
||||
21 11776.0 520.486200 409.599991 376.831982
|
||||
22 12288.0 516.031509 413.911572 383.251457
|
||||
23 12800.0 504.433489 410.420828 375.779805
|
||||
24 13312.0 494.180982 405.699062 376.310952
|
||||
25 13824.0 481.882350 411.888257 378.739711
|
||||
26 14336.0 471.967074 401.709294 372.969090
|
||||
27 14848.0 461.297068 407.492270 375.898745
|
||||
28 15360.0 453.431739 406.887417 378.092307
|
||||
29 15872.0 447.098578 406.323209 376.225175
|
||||
|
||||
|
||||
|
||||
@@ -204,17 +204,19 @@ Layer Normalization
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
UNROLL: tl.constexpr = 4
|
||||
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
|
||||
for j in range(UNROLL):
|
||||
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
tl.store(DW + cols, sum_dw, mask=cols < N)
|
||||
@@ -287,7 +289,15 @@ Layer Normalization
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
)
|
||||
# accumulate partial sums in separate kernel
|
||||
if N > 10240:
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_M = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
BLOCK_SIZE_N = 16
|
||||
BLOCK_SIZE_M = 16
|
||||
num_warps = 8
|
||||
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||
_layer_norm_bwd_dwdb[grid](
|
||||
a, dout,
|
||||
@@ -296,17 +306,11 @@ Layer Normalization
|
||||
dbias,
|
||||
M,
|
||||
N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
num_warps=num_warps
|
||||
)
|
||||
return (da, None, dweight, dbias, None, None,
|
||||
None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None, None, None,
|
||||
None, None, None)
|
||||
return (da, None, dweight, dbias, None)
|
||||
|
||||
|
||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||
@@ -389,7 +393,7 @@ Layer Normalization
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 5 minutes 24.641 seconds)
|
||||
**Total running time of the script:** ( 5 minutes 32.552 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
|
||||
|
@@ -0,0 +1,416 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/06-fused-attention.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_06-fused-attention.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_06-fused-attention.py:
|
||||
|
||||
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 7-355
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
denom = tl.load(L + off_m).to(tl.float32)
|
||||
# compute
|
||||
do = do / denom[:, None]
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_qz + off_h * stride_qh
|
||||
V += off_z * stride_qz + off_h * stride_qh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
lo = start_n * BLOCK_M
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
m_ptrs = M + off_hz * N_CTX
|
||||
# initialize dv and dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-1]
|
||||
assert Lq == Lk
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=64, num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = 64
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# only works on A100 at the moment
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.072 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 06-fused-attention.py <06-fused-attention.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 06-fused-attention.ipynb <06-fused-attention.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -0,0 +1,183 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/07-libdevice-function.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_07-libdevice-function.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_07-libdevice-function.py:
|
||||
|
||||
|
||||
Libdevice function
|
||||
==================
|
||||
Triton can invoke a custom function from an external library.
|
||||
In this example, we will use the `libdevice` library to apply `asin` to a tensor.
|
||||
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
|
||||
|
||||
In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
||||
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
||||
Using Triton, you can simply call `tl.libdevice.asin`.
|
||||
Triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-17
|
||||
|
||||
asin Kernel
|
||||
--------------------------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 17-39
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def asin_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
x = tl.libdevice.asin(x)
|
||||
tl.store(y_ptr + offsets, x, mask=mask)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 40-43
|
||||
|
||||
Using the default libdevice library path
|
||||
----------------------------------------
|
||||
We can use the default libdevice library path encoded in `triton/language/libdevice.py`.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 43-61
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
output_triton = torch.zeros(size, device='cuda')
|
||||
output_torch = torch.asin(x)
|
||||
assert x.is_cuda and output_triton.is_cuda
|
||||
n_elements = output_torch.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
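
Rather than eyeballing the printed maximum difference, the agreement can also be checked programmatically; a minimal sketch using ``torch.allclose``:

.. code-block:: python

    # Sketch: assert the Triton and PyTorch results agree within float32 tolerance.
    assert torch.allclose(output_torch, output_triton, atol=1e-6)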
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 62-65
|
||||
|
||||
Customize the libdevice library path
|
||||
------------------------------------
|
||||
We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-75
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
output_triton = torch.empty_like(x)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
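
The hard-coded path above assumes a default CUDA installation under ``/usr/local/cuda``. When the toolkit lives elsewhere, the same call can be driven from the ``CUDA_HOME`` environment variable; a minimal sketch, assuming the standard ``nvvm/libdevice`` layout inside the toolkit:

.. code-block:: python

    import os

    # Assumption: libdevice ships at <CUDA_HOME>/nvvm/libdevice/libdevice.10.bc.
    cuda_home = os.environ.get('CUDA_HOME', '/usr/local/cuda')
    libdevice_path = os.path.join(cuda_home, 'nvvm', 'libdevice', 'libdevice.10.bc')
    asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
                      extern_libs={'libdevice': libdevice_path})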
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.501 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_07-libdevice-function.py:
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 07-libdevice-function.py <07-libdevice-function.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 07-libdevice-function.ipynb <07-libdevice-function.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -122,6 +122,48 @@ To install the dependencies for the tutorials:
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/05-layer-norm
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_06-fused-attention_thumb.png
|
||||
:alt: Fused Attention
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/06-fused-attention
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_07-libdevice-function_thumb.png
|
||||
:alt: Libdevice function
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/07-libdevice-function
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
@@ -5,16 +5,20 @@
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**16:10.599** total execution time for **getting-started_tutorials** files:
|
||||
**18:09.339** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:52.578 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:13.827 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:24.641 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:32.552 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:18.076 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:32.089 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:34.829 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.020 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.476 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.501 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.072 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
@@ -105,6 +105,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -328,7 +330,7 @@ for different problem sizes.</p>
|
||||
3 32768.0 76.800002 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
6 262144.0 341.333321 341.333321
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
@@ -340,7 +342,7 @@ for different problem sizes.</p>
|
||||
15 134217728.0 849.737435 850.656574
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 34.829 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 50.020 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
|
@@ -108,6 +108,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -369,17 +371,17 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 546.133347 512.000001 188.321838
|
||||
1 384.0 614.400016 585.142862 153.600004
|
||||
2 512.0 655.360017 585.142849 154.566038
|
||||
0 256.0 546.133347 546.133347 186.181817
|
||||
1 384.0 614.400016 585.142862 151.703707
|
||||
2 512.0 655.360017 606.814814 154.566038
|
||||
3 640.0 706.206879 640.000002 160.000000
|
||||
4 768.0 722.823517 664.216187 162.754967
|
||||
4 768.0 722.823517 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 406.179533 198.936606
|
||||
94 12288.0 812.429770 415.222812 199.298541
|
||||
95 12416.0 812.498981 412.149375 198.954424
|
||||
96 12544.0 812.566838 412.758863 199.209928
|
||||
97 12672.0 811.007961 412.097543 199.264875
|
||||
94 12288.0 812.429770 415.222812 199.096718
|
||||
95 12416.0 812.498981 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.012395
|
||||
97 12672.0 811.007961 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
</pre></div>
|
||||
@@ -392,7 +394,7 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
Note however that the PyTorch <cite>softmax</cite> operation is more general and works on tensors of any shape.</p></li>
|
||||
</ul>
|
||||
</div></blockquote>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 18.076 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 32.089 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|
@@ -115,6 +115,8 @@
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -565,42 +567,42 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 256.0 2.730667 ... 2.978909 2.978909
|
||||
0 256.0 2.730667 ... 3.276800 2.978909
|
||||
1 384.0 7.372800 ... 7.899428 7.899428
|
||||
2 512.0 14.563555 ... 15.420235 15.420235
|
||||
2 512.0 14.563555 ... 16.384000 15.420235
|
||||
3 640.0 22.260869 ... 24.380953 24.380953
|
||||
4 768.0 32.768000 ... 35.389441 34.028308
|
||||
5 896.0 37.971025 ... 40.140799 39.025776
|
||||
5 896.0 39.025776 ... 40.140799 39.025776
|
||||
6 1024.0 49.932191 ... 53.773130 52.428801
|
||||
7 1152.0 45.242181 ... 48.161033 47.396572
|
||||
7 1152.0 45.242181 ... 47.396572 47.396572
|
||||
8 1280.0 51.200001 ... 57.690139 57.690139
|
||||
9 1408.0 64.138541 ... 68.147202 65.684049
|
||||
10 1536.0 79.526831 ... 81.355034 78.643199
|
||||
11 1664.0 63.372618 ... 63.372618 62.492442
|
||||
9 1408.0 64.138541 ... 68.147202 66.485074
|
||||
10 1536.0 80.430545 ... 80.430545 78.643199
|
||||
11 1664.0 62.929456 ... 63.372618 62.492442
|
||||
12 1792.0 72.983276 ... 72.983276 59.154861
|
||||
13 1920.0 68.776119 ... 71.626943 70.892307
|
||||
14 2048.0 73.584279 ... 78.033565 76.959706
|
||||
15 2176.0 83.155572 ... 87.494120 86.367588
|
||||
16 2304.0 68.446623 ... 78.064941 77.057651
|
||||
17 2432.0 71.305746 ... 86.179335 85.393507
|
||||
18 2560.0 77.833728 ... 82.956960 81.715711
|
||||
19 2688.0 83.737433 ... 91.185232 89.464755
|
||||
20 2816.0 82.446516 ... 84.523664 83.712490
|
||||
21 2944.0 81.967162 ... 83.758038 82.373605
|
||||
22 3072.0 82.420822 ... 88.750943 86.579673
|
||||
23 3200.0 81.528664 ... 91.233074 95.665176
|
||||
24 3328.0 83.516586 ... 85.908470 83.323259
|
||||
25 3456.0 81.435930 ... 92.138932 90.180725
|
||||
26 3584.0 83.954614 ... 91.189190 95.858629
|
||||
27 3712.0 85.822459 ... 83.806497 87.783251
|
||||
28 3840.0 80.901241 ... 89.259080 89.548180
|
||||
29 3968.0 87.913500 ... 92.829164 84.096442
|
||||
30 4096.0 93.825748 ... 89.299883 90.139506
|
||||
13 1920.0 69.120002 ... 71.257735 71.257735
|
||||
14 2048.0 73.584279 ... 78.398206 77.314362
|
||||
15 2176.0 83.155572 ... 87.494120 85.998493
|
||||
16 2304.0 68.446623 ... 78.320893 77.558029
|
||||
17 2432.0 71.305746 ... 86.711310 75.421383
|
||||
18 2560.0 77.833728 ... 82.747477 81.715711
|
||||
19 2688.0 83.552988 ... 90.532356 89.464755
|
||||
20 2816.0 84.197315 ... 84.035084 84.035084
|
||||
21 2944.0 82.784108 ... 83.969728 83.060049
|
||||
22 3072.0 81.825298 ... 89.593522 88.473602
|
||||
23 3200.0 84.768213 ... 96.096095 95.808380
|
||||
24 3328.0 83.226931 ... 85.908470 84.596116
|
||||
25 3456.0 81.766291 ... 91.824110 91.097818
|
||||
26 3584.0 87.466332 ... 91.194972 94.847460
|
||||
27 3712.0 85.822459 ... 87.246590 87.860458
|
||||
28 3840.0 81.859361 ... 87.011801 90.168771
|
||||
29 3968.0 89.921841 ... 91.954739 85.271796
|
||||
30 4096.0 93.596744 ... 88.243079 90.382307
|
||||
|
||||
[31 rows x 5 columns]
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 52.578 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 7 minutes 13.827 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||
|
@@ -108,6 +108,8 @@
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -372,7 +374,7 @@ to explore the <cite>triton/language/random</cite> folder!</p>
|
||||
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
|
||||
</dd>
|
||||
</dl>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.476 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.279 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
|
||||
|
@@ -46,7 +46,7 @@
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="triton" href="../../python-api/triton.html" />
|
||||
<link rel="next" title="Fused Attention" href="06-fused-attention.html" />
|
||||
<link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
|
||||
</head>
|
||||
|
||||
@@ -101,6 +101,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -196,34 +198,34 @@ to download the full example code</p>
|
||||
N Triton Torch Apex
|
||||
0 1024.0 585.142849 277.694907 468.114273
|
||||
1 1536.0 630.153868 323.368435 511.999982
|
||||
2 2048.0 682.666643 334.367358 520.126988
|
||||
3 2560.0 694.237267 365.714281 518.481028
|
||||
4 3072.0 712.347810 378.092307 501.551037
|
||||
5 3584.0 725.873439 384.859062 458.751978
|
||||
6 4096.0 728.177767 381.023256 458.293714
|
||||
7 4608.0 670.254540 396.387087 426.173427
|
||||
8 5120.0 694.237267 397.669909 426.666652
|
||||
9 5632.0 704.000002 396.969169 413.357796
|
||||
10 6144.0 702.171410 402.885254 411.313806
|
||||
2 2048.0 668.734716 337.814445 528.516136
|
||||
3 2560.0 694.237267 362.477870 512.000013
|
||||
4 3072.0 712.347810 375.206126 501.551037
|
||||
5 3584.0 725.873439 384.859062 451.527536
|
||||
6 4096.0 728.177767 381.023256 455.111095
|
||||
7 4608.0 670.254540 396.387087 421.302872
|
||||
8 5120.0 688.403381 395.748783 422.268057
|
||||
9 5632.0 698.542675 396.969169 409.599997
|
||||
10 6144.0 702.171410 402.885254 409.600010
|
||||
11 6656.0 700.631610 400.360920 400.360920
|
||||
12 7168.0 695.078767 396.844306 388.772874
|
||||
13 7680.0 682.666656 393.846167 387.634072
|
||||
14 8192.0 642.509816 393.609605 372.363633
|
||||
15 8704.0 627.315309 389.005597 380.502740
|
||||
16 9216.0 606.814809 407.337026 383.999986
|
||||
17 9728.0 589.575753 409.599987 383.369452
|
||||
18 10240.0 566.920437 408.578556 382.803739
|
||||
19 10752.0 549.623009 411.559798 381.445676
|
||||
20 11264.0 536.380957 406.826188 373.134567
|
||||
21 11776.0 523.377770 410.492372 377.587162
|
||||
22 12288.0 517.389457 414.784810 383.251457
|
||||
23 12800.0 505.679014 410.420828 376.470582
|
||||
24 13312.0 494.180982 405.699062 376.976995
|
||||
25 13824.0 482.934503 411.888257 379.389355
|
||||
26 14336.0 471.967074 406.695045 374.185964
|
||||
27 14848.0 461.297068 408.192434 375.304904
|
||||
28 15360.0 454.269882 406.214870 378.092307
|
||||
29 15872.0 447.887117 407.627589 376.225175
|
||||
12 7168.0 678.627194 386.154893 384.859062
|
||||
13 7680.0 682.666656 391.337574 386.415087
|
||||
14 8192.0 645.674867 390.095241 376.643677
|
||||
15 8704.0 624.502255 390.095225 379.465939
|
||||
16 9216.0 604.327881 405.098894 383.002605
|
||||
17 9728.0 585.142883 409.599987 382.427505
|
||||
18 10240.0 564.965524 409.600010 382.803739
|
||||
19 10752.0 546.133312 410.577576 380.601764
|
||||
20 11264.0 531.634232 395.228063 370.069806
|
||||
21 11776.0 520.486200 409.599991 376.831982
|
||||
22 12288.0 516.031509 413.911572 383.251457
|
||||
23 12800.0 504.433489 410.420828 375.779805
|
||||
24 13312.0 494.180982 405.699062 376.310952
|
||||
25 13824.0 481.882350 411.888257 378.739711
|
||||
26 14336.0 471.967074 401.709294 372.969090
|
||||
27 14848.0 461.297068 407.492270 375.898745
|
||||
28 15360.0 453.431739 406.887417 378.092307
|
||||
29 15872.0 447.098578 406.323209 376.225175
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="line-block">
|
||||
@@ -354,17 +356,19 @@ to download the full example code</p>
|
||||
<span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||
<span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
|
||||
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">dout</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DOut</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Mean</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o"><</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
|
||||
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Var</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o"><</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
|
||||
<span class="n">a_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">a</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span> <span class="o">*</span> <span class="n">rstd</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">dw</span> <span class="o">+=</span> <span class="n">dout</span> <span class="o">*</span> <span class="n">a_hat</span>
|
||||
<span class="n">db</span> <span class="o">+=</span> <span class="n">dout</span>
|
||||
<span class="n">UNROLL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span> <span class="o">=</span> <span class="mi">4</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">*</span> <span class="n">UNROLL</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">UNROLL</span><span class="p">):</span>
|
||||
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">j</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">dout</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DOut</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Mean</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o"><</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
|
||||
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Var</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o"><</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
|
||||
<span class="n">a_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">a</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span> <span class="o">*</span> <span class="n">rstd</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">dw</span> <span class="o">+=</span> <span class="n">dout</span> <span class="o">*</span> <span class="n">a_hat</span>
|
||||
<span class="n">db</span> <span class="o">+=</span> <span class="n">dout</span>
|
||||
<span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
@@ -437,7 +441,15 @@ to download the full example code</p>
|
||||
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
|
||||
<span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="c1"># accumulate partial sums in separate kernel</span>
|
||||
<span class="k">if</span> <span class="n">N</span> <span class="o">></span> <span class="mi">10240</span><span class="p">:</span>
|
||||
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="mi">128</span>
|
||||
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="mi">32</span>
|
||||
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="c1"># maximize occupancy for small N</span>
|
||||
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="mi">16</span>
|
||||
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="mi">16</span>
|
||||
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s2">"BLOCK_SIZE_N"</span><span class="p">])]</span>
|
||||
<span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">a</span><span class="p">,</span> <span class="n">dout</span><span class="p">,</span>
|
||||
@@ -446,17 +458,11 @@ to download the full example code</p>
|
||||
<span class="n">dbias</span><span class="p">,</span>
|
||||
<span class="n">M</span><span class="p">,</span>
|
||||
<span class="n">N</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">BLOCK_SIZE_N</span><span class="p">,</span>
|
||||
<span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">da</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dweight</span><span class="p">,</span> <span class="n">dbias</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">da</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dweight</span><span class="p">,</span> <span class="n">dbias</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
|
||||
@@ -537,7 +543,7 @@ to download the full example code</p>
|
||||
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">'.'</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 24.641 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 32.552 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
|
||||
@@ -555,7 +561,7 @@ to download the full example code</p>
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="06-fused-attention.html" class="btn btn-neutral float-right" title="Fused Attention" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
|
master/getting-started/tutorials/06-fused-attention.html
@@ -0,0 +1,623 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Fused Attention — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="Libdevice function" href="07-libdevice-function.html" />
|
||||
<link rel="prev" title="Layer Normalization" href="05-layer-norm.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li><a href="index.html">Tutorials</a> »</li>
|
||||
|
||||
<li>Fused Attention</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/06-fused-attention.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="sphx-glr-download-link-note admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">here</span></a>
|
||||
to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="fused-attention">
|
||||
<span id="sphx-glr-getting-started-tutorials-06-fused-attention-py"></span><h1>Fused Attention<a class="headerlink" href="#fused-attention" title="Permalink to this headline">¶</a></h1>
|
||||
<p>This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., <a class="reference external" href="https://arxiv.org/pdf/2205.14135v2.pdf">https://arxiv.org/pdf/2205.14135v2.pdf</a>; Rabe and Staats <a class="reference external" href="https://arxiv.org/pdf/2112.05682v2.pdf">https://arxiv.org/pdf/2112.05682v2.pdf</a>)</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">pytest</span>
|
||||
<span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_fwd_kernel</span><span class="p">(</span>
|
||||
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
|
||||
<span class="n">TMP</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="c1"># NOTE: TMP is a scratchpad buffer to workaround a compiler bug</span>
|
||||
<span class="n">Out</span><span class="p">,</span>
|
||||
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
|
||||
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
|
||||
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
|
||||
<span class="n">stride_oz</span><span class="p">,</span> <span class="n">stride_oh</span><span class="p">,</span> <span class="n">stride_om</span><span class="p">,</span> <span class="n">stride_on</span><span class="p">,</span>
|
||||
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># initialize offsets</span>
|
||||
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">offs_d</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
|
||||
<span class="n">off_q</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
|
||||
<span class="n">off_k</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span>
|
||||
<span class="n">off_v</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
|
||||
<span class="c1"># Initialize pointers to Q, K, V</span>
|
||||
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="n">off_q</span>
|
||||
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="n">off_k</span>
|
||||
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="n">off_v</span>
|
||||
<span class="c1"># initialize pointer to m and l</span>
|
||||
<span class="n">t_ptrs</span> <span class="o">=</span> <span class="n">TMP</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
|
||||
<span class="n">m_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="nb">float</span><span class="p">(</span><span class="s2">"inf"</span><span class="p">)</span>
|
||||
<span class="n">l_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="c1"># load q: it will stay in SRAM throughout</span>
|
||||
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
|
||||
<span class="c1"># loop over k, v and update accumulator</span>
|
||||
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">(</span><span class="n">start_m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">):</span>
|
||||
<span class="n">start_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">multiple_of</span><span class="p">(</span><span class="n">start_n</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="c1"># -- compute qk ----</span>
|
||||
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_kn</span><span class="p">)</span>
|
||||
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">qk</span> <span class="o">*=</span> <span class="n">sm_scale</span>
|
||||
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">>=</span> <span class="p">(</span><span class="n">start_n</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">"-inf"</span><span class="p">))</span>
|
||||
<span class="c1"># -- compute m_ij, p, l_ij</span>
|
||||
<span class="n">m_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">qk</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">-</span> <span class="n">m_ij</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
|
||||
<span class="n">l_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># -- update m_i and l_i</span>
|
||||
<span class="n">m_i_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">m_i</span><span class="p">,</span> <span class="n">m_ij</span><span class="p">)</span>
|
||||
<span class="n">alpha</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_i</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
|
||||
<span class="n">beta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_ij</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
|
||||
<span class="n">l_i_new</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">l_i</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">l_ij</span>
|
||||
<span class="c1"># -- update output accumulator --</span>
|
||||
<span class="c1"># scale p</span>
|
||||
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">/</span> <span class="n">l_i_new</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">p_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># scale acc</span>
|
||||
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">l_i</span> <span class="o">/</span> <span class="n">l_i_new</span> <span class="o">*</span> <span class="n">alpha</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">,</span> <span class="n">acc_scale</span><span class="p">)</span>
|
||||
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">)</span> <span class="c1"># BUG: have to store and immediately load</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">acc</span> <span class="o">*</span> <span class="n">acc_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># update acc</span>
|
||||
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_vk</span><span class="p">)</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
|
||||
<span class="c1"># update m_i and l_i</span>
|
||||
<span class="n">l_i</span> <span class="o">=</span> <span class="n">l_i_new</span>
|
||||
<span class="n">m_i</span> <span class="o">=</span> <span class="n">m_i_new</span>
|
||||
<span class="c1"># rematerialize offsets to save registers</span>
|
||||
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="c1"># write back l and m</span>
|
||||
<span class="n">l_ptrs</span> <span class="o">=</span> <span class="n">L</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
|
||||
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">l_ptrs</span><span class="p">,</span> <span class="n">l_i</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">m_ptrs</span><span class="p">,</span> <span class="n">m_i</span><span class="p">)</span>
|
||||
<span class="c1"># initialize pointers to output</span>
|
||||
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
|
||||
<span class="n">off_o</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_oh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_om</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_on</span>
|
||||
<span class="n">out_ptrs</span> <span class="o">=</span> <span class="n">Out</span> <span class="o">+</span> <span class="n">off_o</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">out_ptrs</span><span class="p">,</span> <span class="n">acc</span><span class="p">)</span>
|
||||
|
||||
|
||||
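
# The loop above is the "online softmax" recurrence used by flash attention:
# each K/V block updates a running row-max `m_i`, a running normalizer `l_i`,
# and a rescaled accumulator `acc`, so the full attention matrix is never
# materialized. Below is a minimal NumPy sketch of the same update for a
# single query row (an illustrative helper, not part of the tutorial code):
import numpy as np


def _online_softmax_reference(s, v, block=64):
    # s: (n,) attention scores for one query row, v: (n, d) value rows
    m, l = -np.inf, 0.0
    acc = np.zeros(v.shape[1])
    for i in range(0, len(s), block):
        s_blk, v_blk = s[i:i + block], v[i:i + block]
        m_blk = s_blk.max()
        p = np.exp(s_blk - m_blk)  # shifted local exponentials
        m_new = max(m, m_blk)
        alpha, beta = np.exp(m - m_new), np.exp(m_blk - m_new)
        l_new = alpha * l + beta * p.sum()
        acc = acc * (l / l_new * alpha) + (p * (beta / l_new)) @ v_blk
        m, l = m_new, l_new
    return acc  # equals softmax(s) @ v
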
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_bwd_preprocess</span><span class="p">(</span>
|
||||
<span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span>
|
||||
<span class="n">NewDO</span><span class="p">,</span> <span class="n">Delta</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">off_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">off_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">)</span>
|
||||
<span class="c1"># load</span>
|
||||
<span class="n">o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Out</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">denom</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">L</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="c1"># compute</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span> <span class="o">/</span> <span class="n">denom</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">delta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">o</span> <span class="o">*</span> <span class="n">do</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># write-back</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">NewDO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">do</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Delta</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">,</span> <span class="n">delta</span><span class="p">)</span>
|
||||
|
||||
|
||||
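
# _bwd_preprocess above prepares two row-wise quantities for the main backward
# kernel: it rescales dO by the softmax normalizer `l` saved during the forward
# pass, and it computes delta[i] = sum_j O[i, j] * dO_scaled[i, j], the term the
# softmax backward subtracts from dP. A rough PyTorch equivalent (illustrative
# only, not used by the tutorial) would be:
#     do_scaled = do.float() / l[:, None]
#     delta = (o.float() * do_scaled).sum(dim=-1)
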
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_bwd_kernel</span><span class="p">(</span>
|
||||
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span> <span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span>
|
||||
<span class="n">DQ</span><span class="p">,</span> <span class="n">DK</span><span class="p">,</span> <span class="n">DV</span><span class="p">,</span>
|
||||
<span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span>
|
||||
<span class="n">D</span><span class="p">,</span>
|
||||
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
|
||||
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
|
||||
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
|
||||
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
|
||||
<span class="n">num_block</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">off_z</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">//</span> <span class="n">H</span>
|
||||
<span class="n">off_h</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">%</span> <span class="n">H</span>
|
||||
<span class="c1"># offset pointers for batch/head</span>
|
||||
<span class="n">Q</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">K</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">V</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">DO</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">DQ</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">DK</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="n">DV</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
|
||||
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_block</span><span class="p">):</span>
|
||||
<span class="n">lo</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span>
|
||||
<span class="c1"># initialize row/col offsets</span>
|
||||
<span class="n">offs_qm</span> <span class="o">=</span> <span class="n">lo</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
|
||||
<span class="c1"># initialize pointers to value-like data</span>
|
||||
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
|
||||
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
|
||||
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
|
||||
<span class="n">do_ptrs</span> <span class="o">=</span> <span class="n">DO</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
|
||||
<span class="n">dq_ptrs</span> <span class="o">=</span> <span class="n">DQ</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
|
||||
<span class="c1"># pointer to row-wise quantities in value-like data</span>
|
||||
<span class="n">D_ptrs</span> <span class="o">=</span> <span class="n">D</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
|
||||
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
|
||||
<span class="c1"># initialize dv amd dk</span>
|
||||
<span class="n">dv</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">dk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="c1"># k and v stay in SRAM throughout</span>
|
||||
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span><span class="p">)</span>
|
||||
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span><span class="p">)</span>
|
||||
<span class="c1"># loop over rows</span>
|
||||
<span class="k">for</span> <span class="n">start_m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">lo</span><span class="p">,</span> <span class="n">num_block</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">):</span>
|
||||
<span class="n">offs_m_curr</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">+</span> <span class="n">offs_m</span>
|
||||
<span class="c1"># load q, k, v, do on-chip</span>
|
||||
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
|
||||
<span class="c1"># recompute p = softmax(qk, dim=-1).T</span>
|
||||
<span class="c1"># NOTE: `do` is pre-divided by `l`; no normalization here</span>
|
||||
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m_curr</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">>=</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="n">qk</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">"-inf"</span><span class="p">))</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">m_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">*</span> <span class="n">sm_scale</span> <span class="o">-</span> <span class="n">m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
|
||||
<span class="c1"># compute dv</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">do_ptrs</span><span class="p">)</span>
|
||||
<span class="n">dv</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">do</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="c1"># compute dp = dot(v, do)</span>
|
||||
<span class="n">Di</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">D_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
|
||||
<span class="n">dp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="n">Di</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">dp</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="c1"># compute ds = p * (dp - delta[:, None])</span>
|
||||
<span class="n">ds</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">dp</span> <span class="o">*</span> <span class="n">sm_scale</span>
|
||||
<span class="c1"># compute dk = dot(ds.T, q)</span>
|
||||
<span class="n">dk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">q</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="c1"># # compute dq</span>
|
||||
<span class="n">dq</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">"evict_last"</span><span class="p">)</span>
|
||||
<span class="n">dq</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">k</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">dq</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">"evict_last"</span><span class="p">)</span>
|
||||
<span class="c1"># # increment pointers</span>
|
||||
<span class="n">dq_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
|
||||
<span class="n">q_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
|
||||
<span class="n">do_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
|
||||
<span class="c1"># write-back</span>
|
||||
<span class="n">dv_ptrs</span> <span class="o">=</span> <span class="n">DV</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
|
||||
<span class="n">dk_ptrs</span> <span class="o">=</span> <span class="n">DK</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dv_ptrs</span><span class="p">,</span> <span class="n">dv</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dk_ptrs</span><span class="p">,</span> <span class="n">dk</span><span class="p">)</span>
|
||||
|
||||
|
||||
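
# Note on _bwd_kernel above: the outer loop walks over key/value blocks and
# keeps the corresponding dk/dv tiles in registers until the final write-back,
# while dq is accumulated in global memory with a load/add/store per query
# block (hence the "evict_last" eviction policy on those accesses).
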
<span class="k">class</span> <span class="nc">_attention</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">):</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="mi">128</span>
|
||||
<span class="c1"># shape constraints</span>
|
||||
<span class="n">Lq</span><span class="p">,</span> <span class="n">Lk</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="k">assert</span> <span class="n">Lq</span> <span class="o">==</span> <span class="n">Lk</span>
|
||||
<span class="n">o</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">BLOCK</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">L</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">_fwd_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
|
||||
<span class="n">tmp</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
|
||||
<span class="n">o</span><span class="p">,</span>
|
||||
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
|
||||
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
|
||||
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span> <span class="o">=</span> <span class="n">BLOCK</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span> <span class="o">=</span> <span class="n">sm_scale</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span> <span class="o">=</span> <span class="mi">64</span>
|
||||
<span class="k">return</span> <span class="n">o</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">do</span><span class="p">):</span>
|
||||
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
||||
<span class="n">dq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">dk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
|
||||
<span class="n">dv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
|
||||
<span class="n">do_scaled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">do</span><span class="p">)</span>
|
||||
<span class="n">delta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">l</span><span class="p">)</span>
|
||||
<span class="n">_bwd_preprocess</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">)](</span>
|
||||
<span class="n">o</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span>
|
||||
<span class="n">do_scaled</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">_bwd_kernel</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],)](</span>
|
||||
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span><span class="p">,</span>
|
||||
<span class="n">o</span><span class="p">,</span> <span class="n">do_scaled</span><span class="p">,</span>
|
||||
<span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span>
|
||||
<span class="n">l</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
|
||||
<span class="n">delta</span><span class="p">,</span>
|
||||
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
|
||||
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
|
||||
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||||
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
|
||||
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span> <span class="kc">None</span>
|
||||
|
||||
|
||||
<span class="n">attention</span> <span class="o">=</span> <span class="n">_attention</span><span class="o">.</span><span class="n">apply</span>
|
||||
|
||||
|
||||
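
# Example usage (illustrative; the test below does the same with sm_scale=0.3):
# q, k, v are float16 CUDA tensors of shape (Z, H, N_CTX, D_HEAD), with D_HEAD
# matching the hard-coded BLOCK_DMODEL of 64.
#     sm_scale = 1.0 / (q.shape[-1] ** 0.5)  # conventional softmax scaling
#     out = attention(q, k, v, sm_scale)
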
<span class="nd">@pytest</span><span class="o">.</span><span class="n">mark</span><span class="o">.</span><span class="n">parametrize</span><span class="p">(</span><span class="s1">'Z, H, N_CTX, D_HEAD'</span><span class="p">,</span> <span class="p">[(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">64</span><span class="p">)])</span>
|
||||
<span class="k">def</span> <span class="nf">test_op</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span>
|
||||
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
|
||||
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
|
||||
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
|
||||
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">0.3</span>
|
||||
<span class="n">dout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
|
||||
<span class="c1"># reference implementation</span>
|
||||
<span class="n">M</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">))</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="o">*</span> <span class="n">sm_scale</span>
|
||||
<span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="p">):</span>
|
||||
<span class="n">p</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">M</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s2">"-inf"</span><span class="p">)</span>
|
||||
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">half</span><span class="p">()</span>
|
||||
<span class="n">ref_out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
|
||||
<span class="n">ref_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
|
||||
<span class="n">ref_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="n">ref_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="n">ref_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="c1"># triton implementation</span>
|
||||
<span class="n">tri_out</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
|
||||
<span class="n">tri_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
|
||||
<span class="n">tri_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="n">tri_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="n">tri_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
|
||||
<span class="c1"># compare</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_out</span><span class="p">,</span> <span class="n">tri_out</span><span class="p">)</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dv</span><span class="p">,</span> <span class="n">tri_dv</span><span class="p">)</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dk</span><span class="p">,</span> <span class="n">tri_dk</span><span class="p">)</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dq</span><span class="p">,</span> <span class="n">tri_dq</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">try</span><span class="p">:</span>
|
||||
<span class="kn">from</span> <span class="nn">flash_attn.flash_attn_interface</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>
|
||||
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="k">except</span> <span class="ne">BaseException</span><span class="p">:</span>
|
||||
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="n">BATCH</span><span class="p">,</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">4096</span><span class="p">,</span> <span class="mi">64</span>
|
||||
<span class="c1"># vary seq length for fixed head and batch=4</span>
|
||||
<span class="n">configs</span> <span class="o">=</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N_CTX'</span><span class="p">],</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">16</span><span class="p">)],</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'flash'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'Triton'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'Flash'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
|
||||
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'red'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">)],</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s1">'ms'</span><span class="p">,</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="sa">f</span><span class="s1">'fused-attention-batch</span><span class="si">{</span><span class="n">BATCH</span><span class="si">}</span><span class="s1">-head</span><span class="si">{</span><span class="n">N_HEADS</span><span class="si">}</span><span class="s1">-d</span><span class="si">{</span><span class="n">D_HEAD</span><span class="si">}</span><span class="s1">-</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'H'</span><span class="p">:</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="s1">'BATCH'</span><span class="p">:</span> <span class="n">BATCH</span><span class="p">,</span> <span class="s1">'D_HEAD'</span><span class="p">:</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="s1">'dtype'</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">'mode'</span><span class="p">:</span> <span class="n">mode</span><span class="p">}</span>
|
||||
<span class="p">)</span> <span class="k">for</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'bwd'</span><span class="p">]]</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">bench_flash_attention</span><span class="p">(</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">):</span>
|
||||
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'fwd'</span><span class="p">,</span> <span class="s1">'bwd'</span><span class="p">]</span>
|
||||
<span class="n">warmup</span> <span class="o">=</span> <span class="mi">25</span>
|
||||
<span class="n">rep</span> <span class="o">=</span> <span class="mi">100</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">"triton"</span><span class="p">:</span>
|
||||
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">1.3</span>
|
||||
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'bwd'</span><span class="p">:</span>
|
||||
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
|
||||
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ms</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">"flash"</span><span class="p">:</span>
|
||||
<span class="n">lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,),</span> <span class="n">fill_value</span><span class="o">=</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">cu_seqlens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
||||
<span class="n">cu_seqlens</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">lengths</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">*</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">cu_seqlens</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'bwd'</span><span class="p">:</span>
|
||||
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
|
||||
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">ms</span>
|
||||
|
||||
<span class="c1"># only works on A100 at the moment</span>
|
||||
<span class="c1"># bench_flash_attention.run(save_path='.', print_data=True)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.072 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-06-fused-attention-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">06-fused-attention.py</span></code></a></p>
|
||||
</div>
|
||||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">06-fused-attention.ipynb</span></code></a></p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="07-libdevice-function.html" class="btn btn-neutral float-right" title="Libdevice function" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="05-layer-norm.html" class="btn btn-neutral float-left" title="Layer Normalization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||||
<span class="fa fa-book"> Other Versions</span>
|
||||
v: master
|
||||
<span class="fa fa-caret-down"></span>
|
||||
</span>
|
||||
<div class="rst-other-versions">
|
||||
<dl>
|
||||
<dt>Tags</dt>
|
||||
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
|
||||
</dl>
|
||||
<dl>
|
||||
<dt>Branches</dt>
|
||||
<dd><a href="06-fused-attention.html">master</a></dd>
|
||||
</dl>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
master/getting-started/tutorials/07-libdevice-function.html
@@ -0,0 +1,357 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Libdevice function — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="triton" href="../../python-api/triton.html" />
|
||||
<link rel="prev" title="Fused Attention" href="06-fused-attention.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Libdevice function</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#asin-kernel">asin Kernel</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#using-the-default-libdevice-library-path">Using the default libdevice library path</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#customize-the-libdevice-library-path">Customize the libdevice library path</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li><a href="index.html">Tutorials</a> »</li>
|
||||
|
||||
<li>Libdevice function</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/07-libdevice-function.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="sphx-glr-download-link-note admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">here</span></a>
|
||||
to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="libdevice-function">
|
||||
<span id="sphx-glr-getting-started-tutorials-07-libdevice-function-py"></span><h1>Libdevice function<a class="headerlink" href="#libdevice-function" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Triton can invoke a custom function from an external library.
|
||||
In this example, we will use the <cite>libdevice</cite> library to apply <cite>asin</cite> to a tensor.
|
||||
Please refer to <a class="reference external" href="https://docs.nvidia.com/cuda/libdevice-users-guide/index.html">https://docs.nvidia.com/cuda/libdevice-users-guide/index.html</a> regarding the semantics of all available libdevice functions.</p>
|
||||
<p>In <cite>triton/language/libdevice.py</cite>, we try to aggregate functions with the same computation but different data types together.
|
||||
For example, both <cite>__nv_asin</cite> and <cite>__nv_asinf</cite> calculate the principal value of the arc sine of the input, but <cite>__nv_asin</cite> operates on <cite>double</cite> and <cite>__nv_asinf</cite> operates on <cite>float</cite>.
|
||||
Using Triton, you can simply call <cite>tl.libdevice.asin</cite>.
|
||||
Triton automatically selects the correct underlying device function to invoke based on input and output types.</p>
|
||||
<div class="section" id="asin-kernel">
|
||||
<h2>asin Kernel<a class="headerlink" href="#asin-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">asin_kernel</span><span class="p">(</span>
|
||||
<span class="n">x_ptr</span><span class="p">,</span>
|
||||
<span class="n">y_ptr</span><span class="p">,</span>
|
||||
<span class="n">n_elements</span><span class="p">,</span>
|
||||
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
|
||||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">n_elements</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">libdevice</span><span class="o">.</span><span class="n">asin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="using-the-default-libdevice-library-path">
|
||||
<h2>Using the default libdevice library path<a class="headerlink" href="#using-the-default-libdevice-library-path" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We can use the default libdevice library path encoded in <cite>triton/language/libdevice.py</cite>.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">asin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output_triton</span><span class="o">.</span><span class="n">is_cuda</span>
|
||||
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">output_torch</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
|
||||
<span class="n">asin_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output_triton</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
|
||||
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span>
|
||||
<span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="customize-the-libdevice-library-path">
|
||||
<h2>Customize the libdevice library path<a class="headerlink" href="#customize-the-libdevice-library-path" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We can also customize the libdevice library path by passing the path to the <cite>libdevice</cite> library to the <cite>asin</cite> kernel.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">output_triton</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">asin_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output_triton</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
|
||||
<span class="n">extern_libs</span><span class="o">=</span><span class="p">{</span><span class="s1">'libdevice'</span><span class="p">:</span> <span class="s1">'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'</span><span class="p">})</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span>
|
||||
<span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
|
||||
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span>
|
||||
<span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.501 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-07-libdevice-function-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">07-libdevice-function.py</span></code></a></p>
|
||||
</div>
|
||||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">07-libdevice-function.ipynb</span></code></a></p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="06-fused-attention.html" class="btn btn-neutral float-left" title="Fused Attention" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||||
<span class="fa fa-book"> Other Versions</span>
|
||||
v: master
|
||||
<span class="fa fa-caret-down"></span>
|
||||
</span>
|
||||
<div class="rst-other-versions">
|
||||
<dl>
|
||||
<dt>Tags</dt>
|
||||
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
|
||||
</dl>
|
||||
<dl>
|
||||
<dt>Branches</dt>
|
||||
<dd><a href="07-libdevice-function.html">master</a></dd>
|
||||
</dl>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
@@ -101,6 +101,8 @@
|
||||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
@@ -219,6 +221,18 @@ pip install -e <span class="s1">'./python[tutorials]'</span>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention"><div class="figure align-default" id="id6">
|
||||
<img alt="Fused Attention" src="../../_images/sphx_glr_06-fused-attention_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">Fused Attention</span></a></span><a class="headerlink" href="#id6" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d..."><div class="figure align-default" id="id7">
|
||||
<img alt="Libdevice function" src="../../_images/sphx_glr_07-libdevice-function_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="07-libdevice-function.html#sphx-glr-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">Libdevice function</span></a></span><a class="headerlink" href="#id7" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>
|
||||
|
@@ -174,7 +174,7 @@
|
||||
|
||||
<div class="section" id="computation-times">
|
||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>16:10.599</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<p><strong>18:09.339</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<table class="docutils align-default">
|
||||
<colgroup>
|
||||
<col style="width: 85%" />
|
||||
@@ -183,23 +183,31 @@
|
||||
</colgroup>
|
||||
<tbody>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
||||
<td><p>05:52.578</p></td>
|
||||
<td><p>07:13.827</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Layer Normalization</span></a> (<code class="docutils literal notranslate"><span class="pre">05-layer-norm.py</span></code>)</p></td>
|
||||
<td><p>05:24.641</p></td>
|
||||
<td><p>05:32.552</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>03:18.076</p></td>
|
||||
<td><p>03:32.089</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>01:34.829</p></td>
|
||||
<td><p>01:50.020</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
|
||||
<td><p>00:00.476</p></td>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="07-libdevice-function.html#sphx-glr-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">Libdevice function</span></a> (<code class="docutils literal notranslate"><span class="pre">07-libdevice-function.py</span></code>)</p></td>
|
||||
<td><p>00:00.501</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
|
||||
<td><p>00:00.279</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">Fused Attention</span></a> (<code class="docutils literal notranslate"><span class="pre">06-fused-attention.py</span></code>)</p></td>
|
||||
<td><p>00:00.072</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
|
@@ -197,7 +197,7 @@
|
||||
<h1>triton.language.dot<a class="headerlink" href="#triton-language-dot" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt class="sig sig-object py" id="triton.language.dot">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_tf32</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition">¶</a></dt>
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trans_a</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trans_b</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_tf32</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the matrix product of two blocks.</p>
|
||||
<p>The two blocks must be two-dimensional and have compatible inner dimensions.</p>
|
||||
<dl class="field-list simple">
|
||||
|
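<p>The hunk above updates the <cite>triton.language.dot</cite> signature with the new <cite>trans_a</cite> and <cite>trans_b</cite> keywords. Below is a minimal sketch of a kernel calling the updated signature; the tile size, contiguous row-major pointer layout, and the launch code are illustrative assumptions and not part of the documented change.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>
import torch

import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # Load two BLOCK x BLOCK tiles assumed to be stored contiguously, row-major
    # (an illustrative layout, not prescribed by the API change above).
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # trans_a / trans_b are the keywords added in the signature above; with both
    # left at False this computes a @ b, while trans_b=True would use b transposed.
    c = tl.dot(a, b, trans_a=False, trans_b=False, allow_tf32=True)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


# Hypothetical launch: a single program computing one 32x32 fp16 tile.
a = torch.randn((32, 32), device='cuda', dtype=torch.float16)
b = torch.randn((32, 32), device='cuda', dtype=torch.float16)
c = torch.empty((32, 32), device='cuda', dtype=torch.float16)
dot_kernel[(1,)](a, b, c, BLOCK=32)
</pre></div>
</div>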
@@ -200,7 +200,7 @@
|
||||
<h1>triton.language.store<a class="headerlink" href="#triton-language-store" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt class="sig sig-object py" id="triton.language.store">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">store</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition">¶</a></dt>
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">store</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eviction_policy</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Stores <code class="code docutils literal notranslate"><span class="pre">value</span></code> tensor of elements in memory, element-wise, at the memory locations specified by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
|
||||
<p><code class="code docutils literal notranslate"><span class="pre">value</span></code> is implicitly broadcast to <code class="code docutils literal notranslate"><span class="pre">pointer.shape</span></code> and typecast to <code class="code docutils literal notranslate"><span class="pre">pointer.dtype.element_ty</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
|
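<p>The hunk above adds an <cite>eviction_policy</cite> keyword to <cite>triton.language.store</cite>. The following is a minimal sketch of a copy kernel passing the new keyword; the <cite>"evict_first"</cite> value and the launch parameters are assumptions for illustration, and leaving the argument as the documented default (an empty string) keeps the standard cache behaviour.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>
import torch

import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets &lt; n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    # eviction_policy is the keyword added in the signature above; "evict_first"
    # is an assumed hint value for data that is written once and not re-read soon.
    tl.store(dst_ptr + offsets, x, mask=mask, eviction_policy="evict_first")


# Hypothetical launch copying a 4096-element vector.
src = torch.randn(4096, device='cuda')
dst = torch.empty_like(src)
grid = lambda meta: (triton.cdiv(src.numel(), meta['BLOCK_SIZE']),)
copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=1024)
</pre></div>
</div>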