[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2022-07-14 07:22:19 +00:00
parent 3e815114fd
commit d1c6625bfd
179 changed files with 2617 additions and 369 deletions

View File

@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 255e5b7e27427f0ab6fda308ee6aef63
config: 8d52c5eda79abb41e578ed40b306519c
tags: 645f666f9bcd5a90fca523b33c5a78b7

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,97 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Libdevice function\nTriton can invoke a custom function from an external library.\nIn this example, we will use the `libdevice` library to apply `asin` on a tensor.\nPlease refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.\n\nIn `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.\nFor example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.\nUsing triton, you can simply call `tl.libdevice.asinf`.\ntriton automatically selects the correct underlying device function to invoke based on input and output types.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## asin Kernel\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x = tl.libdevice.asin(x)\n tl.store(y_ptr + offsets, x, mask=mask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the default libdevice library path\nWe can use the default libdevice library path encoded in `triton/language/libdevice.py`\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Customize the libdevice library path\nWe can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"output_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,74 @@
"""
Libdevice function
===============
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
"""
# %%
# asin Kernel
# --------------------------
import torch
import triton
import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x = tl.libdevice.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
# %%
# Using the default libdevice library path
# --------------------------
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
# %%
# Customize the libdevice library path
# --------------------------
# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)

View File

@@ -0,0 +1,354 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=64, num_warps=4,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = 64
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
# accumulate partial sums in separate kernel
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
dbias,
M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None, None,
None, None, None, None,
None,
None, None, None,
None,
None, None, None,
None, None, None,
None, None, None)
return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):

File diff suppressed because one or more lines are too long

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

View File

@@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 384.000001
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
@@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 34.829 seconds)
**Total running time of the script:** ( 1 minutes 50.020 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 546.133347 512.000001 188.321838
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 585.142849 154.566038
0 256.0 546.133347 546.133347 186.181817
1 384.0 614.400016 585.142862 151.703707
2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
4 768.0 722.823517 664.216187 162.754967
4 768.0 722.823517 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 198.936606
94 12288.0 812.429770 415.222812 199.298541
95 12416.0 812.498981 412.149375 198.954424
96 12544.0 812.566838 412.758863 199.209928
97 12672.0 811.007961 412.097543 199.264875
94 12288.0 812.429770 415.222812 199.096718
95 12416.0 812.498981 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.012395
97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns]
@@ -306,7 +306,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 3 minutes 18.076 seconds)
**Total running time of the script:** ( 3 minutes 32.089 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
0 256.0 2.730667 ... 3.276800 2.978909
1 384.0 7.372800 ... 7.899428 7.899428
2 512.0 14.563555 ... 15.420235 15.420235
2 512.0 14.563555 ... 16.384000 15.420235
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 35.389441 34.028308
5 896.0 37.971025 ... 40.140799 39.025776
5 896.0 39.025776 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 45.242181 ... 48.161033 47.396572
7 1152.0 45.242181 ... 47.396572 47.396572
8 1280.0 51.200001 ... 57.690139 57.690139
9 1408.0 64.138541 ... 68.147202 65.684049
10 1536.0 79.526831 ... 81.355034 78.643199
11 1664.0 63.372618 ... 63.372618 62.492442
9 1408.0 64.138541 ... 68.147202 66.485074
10 1536.0 80.430545 ... 80.430545 78.643199
11 1664.0 62.929456 ... 63.372618 62.492442
12 1792.0 72.983276 ... 72.983276 59.154861
13 1920.0 68.776119 ... 71.626943 70.892307
14 2048.0 73.584279 ... 78.033565 76.959706
15 2176.0 83.155572 ... 87.494120 86.367588
16 2304.0 68.446623 ... 78.064941 77.057651
17 2432.0 71.305746 ... 86.179335 85.393507
18 2560.0 77.833728 ... 82.956960 81.715711
19 2688.0 83.737433 ... 91.185232 89.464755
20 2816.0 82.446516 ... 84.523664 83.712490
21 2944.0 81.967162 ... 83.758038 82.373605
22 3072.0 82.420822 ... 88.750943 86.579673
23 3200.0 81.528664 ... 91.233074 95.665176
24 3328.0 83.516586 ... 85.908470 83.323259
25 3456.0 81.435930 ... 92.138932 90.180725
26 3584.0 83.954614 ... 91.189190 95.858629
27 3712.0 85.822459 ... 83.806497 87.783251
28 3840.0 80.901241 ... 89.259080 89.548180
29 3968.0 87.913500 ... 92.829164 84.096442
30 4096.0 93.825748 ... 89.299883 90.139506
13 1920.0 69.120002 ... 71.257735 71.257735
14 2048.0 73.584279 ... 78.398206 77.314362
15 2176.0 83.155572 ... 87.494120 85.998493
16 2304.0 68.446623 ... 78.320893 77.558029
17 2432.0 71.305746 ... 86.711310 75.421383
18 2560.0 77.833728 ... 82.747477 81.715711
19 2688.0 83.552988 ... 90.532356 89.464755
20 2816.0 84.197315 ... 84.035084 84.035084
21 2944.0 82.784108 ... 83.969728 83.060049
22 3072.0 81.825298 ... 89.593522 88.473602
23 3200.0 84.768213 ... 96.096095 95.808380
24 3328.0 83.226931 ... 85.908470 84.596116
25 3456.0 81.766291 ... 91.824110 91.097818
26 3584.0 87.466332 ... 91.194972 94.847460
27 3712.0 85.822459 ... 87.246590 87.860458
28 3840.0 81.859361 ... 87.011801 90.168771
29 3968.0 89.921841 ... 91.954739 85.271796
30 4096.0 93.596744 ... 88.243079 90.382307
[31 rows x 5 columns]
@@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 52.578 seconds)
**Total running time of the script:** ( 7 minutes 13.827 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:

View File

@@ -240,7 +240,7 @@ References
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.476 seconds)
**Total running time of the script:** ( 0 minutes 0.279 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:

View File

@@ -21,7 +21,7 @@
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-312
.. GENERATED FROM PYTHON SOURCE LINES 5-316
@@ -40,34 +40,34 @@ Layer Normalization
N Triton Torch Apex
0 1024.0 585.142849 277.694907 468.114273
1 1536.0 630.153868 323.368435 511.999982
2 2048.0 682.666643 334.367358 520.126988
3 2560.0 694.237267 365.714281 518.481028
4 3072.0 712.347810 378.092307 501.551037
5 3584.0 725.873439 384.859062 458.751978
6 4096.0 728.177767 381.023256 458.293714
7 4608.0 670.254540 396.387087 426.173427
8 5120.0 694.237267 397.669909 426.666652
9 5632.0 704.000002 396.969169 413.357796
10 6144.0 702.171410 402.885254 411.313806
2 2048.0 668.734716 337.814445 528.516136
3 2560.0 694.237267 362.477870 512.000013
4 3072.0 712.347810 375.206126 501.551037
5 3584.0 725.873439 384.859062 451.527536
6 4096.0 728.177767 381.023256 455.111095
7 4608.0 670.254540 396.387087 421.302872
8 5120.0 688.403381 395.748783 422.268057
9 5632.0 698.542675 396.969169 409.599997
10 6144.0 702.171410 402.885254 409.600010
11 6656.0 700.631610 400.360920 400.360920
12 7168.0 695.078767 396.844306 388.772874
13 7680.0 682.666656 393.846167 387.634072
14 8192.0 642.509816 393.609605 372.363633
15 8704.0 627.315309 389.005597 380.502740
16 9216.0 606.814809 407.337026 383.999986
17 9728.0 589.575753 409.599987 383.369452
18 10240.0 566.920437 408.578556 382.803739
19 10752.0 549.623009 411.559798 381.445676
20 11264.0 536.380957 406.826188 373.134567
21 11776.0 523.377770 410.492372 377.587162
22 12288.0 517.389457 414.784810 383.251457
23 12800.0 505.679014 410.420828 376.470582
24 13312.0 494.180982 405.699062 376.976995
25 13824.0 482.934503 411.888257 379.389355
26 14336.0 471.967074 406.695045 374.185964
27 14848.0 461.297068 408.192434 375.304904
28 15360.0 454.269882 406.214870 378.092307
29 15872.0 447.887117 407.627589 376.225175
12 7168.0 678.627194 386.154893 384.859062
13 7680.0 682.666656 391.337574 386.415087
14 8192.0 645.674867 390.095241 376.643677
15 8704.0 624.502255 390.095225 379.465939
16 9216.0 604.327881 405.098894 383.002605
17 9728.0 585.142883 409.599987 382.427505
18 10240.0 564.965524 409.600010 382.803739
19 10752.0 546.133312 410.577576 380.601764
20 11264.0 531.634232 395.228063 370.069806
21 11776.0 520.486200 409.599991 376.831982
22 12288.0 516.031509 413.911572 383.251457
23 12800.0 504.433489 410.420828 375.779805
24 13312.0 494.180982 405.699062 376.310952
25 13824.0 481.882350 411.888257 378.739711
26 14336.0 471.967074 401.709294 372.969090
27 14848.0 461.297068 407.492270 375.898745
28 15360.0 453.431739 406.887417 378.092307
29 15872.0 447.098578 406.323209 376.225175
@@ -204,17 +204,19 @@ Layer Normalization
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -287,7 +289,15 @@ Layer Normalization
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
# accumulate partial sums in separate kernel
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -296,17 +306,11 @@ Layer Normalization
dbias,
M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None, None,
None, None, None, None,
None,
None, None, None,
None,
None, None, None,
None, None, None,
None, None, None)
return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):
@@ -389,7 +393,7 @@ Layer Normalization
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 24.641 seconds)
**Total running time of the script:** ( 5 minutes 32.552 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:

View File

@@ -0,0 +1,416 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/06-fused-attention.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_06-fused-attention.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_06-fused-attention.py:
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
.. GENERATED FROM PYTHON SOURCE LINES 7-355
.. code-block:: default
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=64, num_warps=4,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = 64
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.072 seconds)
.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 06-fused-attention.py <06-fused-attention.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 06-fused-attention.ipynb <06-fused-attention.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,183 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/07-libdevice-function.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_07-libdevice-function.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_07-libdevice-function.py:
Libdevice function
===============
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
.. GENERATED FROM PYTHON SOURCE LINES 15-17
asin Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 17-39
.. code-block:: default
import torch
import triton
import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x = tl.libdevice.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
.. GENERATED FROM PYTHON SOURCE LINES 40-43
Using the default libdevice library path
--------------------------
We can use the default libdevice library path encoded in `triton/language/libdevice.py`
.. GENERATED FROM PYTHON SOURCE LINES 43-61
.. code-block:: default
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
.. GENERATED FROM PYTHON SOURCE LINES 62-65
Customize the libdevice library path
--------------------------
We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
.. GENERATED FROM PYTHON SOURCE LINES 65-75
.. code-block:: default
output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.501 seconds)
.. _sphx_glr_download_getting-started_tutorials_07-libdevice-function.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 07-libdevice-function.py <07-libdevice-function.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 07-libdevice-function.ipynb <07-libdevice-function.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -122,6 +122,48 @@ To install the dependencies for the tutorials:
:hidden:
/getting-started/tutorials/05-layer-norm
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_06-fused-attention_thumb.png
:alt: Fused Attention
:ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/06-fused-attention
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_07-libdevice-function_thumb.png
:alt: Libdevice function
:ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/07-libdevice-function
.. raw:: html
<div class="sphx-glr-clear"></div>

View File

@@ -5,16 +5,20 @@
Computation times
=================
**16:10.599** total execution time for **getting-started_tutorials** files:
**18:09.339** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:52.578 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:13.827 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:24.641 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:32.552 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:18.076 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:32.089 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:34.829 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.020 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.476 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.501 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.072 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+

View File

@@ -105,6 +105,8 @@
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -328,7 +330,7 @@ for different problem sizes.</p>
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 384.000001
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
@@ -340,7 +342,7 @@ for different problem sizes.</p>
15 134217728.0 849.737435 850.656574
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 34.829 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 50.020 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>

View File

@@ -108,6 +108,8 @@
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -369,17 +371,17 @@ We will then compare its performance against (1) <code class="code docutils lite
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 546.133347 512.000001 188.321838
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 585.142849 154.566038
0 256.0 546.133347 546.133347 186.181817
1 384.0 614.400016 585.142862 151.703707
2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
4 768.0 722.823517 664.216187 162.754967
4 768.0 722.823517 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 198.936606
94 12288.0 812.429770 415.222812 199.298541
95 12416.0 812.498981 412.149375 198.954424
96 12544.0 812.566838 412.758863 199.209928
97 12672.0 811.007961 412.097543 199.264875
94 12288.0 812.429770 415.222812 199.096718
95 12416.0 812.498981 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.012395
97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns]
</pre></div>
@@ -392,7 +394,7 @@ We will then compare its performance against (1) <code class="code docutils lite
Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
</ul>
</div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 18.076 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 32.089 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>

View File

@@ -115,6 +115,8 @@
</li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -565,42 +567,42 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
0 256.0 2.730667 ... 3.276800 2.978909
1 384.0 7.372800 ... 7.899428 7.899428
2 512.0 14.563555 ... 15.420235 15.420235
2 512.0 14.563555 ... 16.384000 15.420235
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 35.389441 34.028308
5 896.0 37.971025 ... 40.140799 39.025776
5 896.0 39.025776 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 45.242181 ... 48.161033 47.396572
7 1152.0 45.242181 ... 47.396572 47.396572
8 1280.0 51.200001 ... 57.690139 57.690139
9 1408.0 64.138541 ... 68.147202 65.684049
10 1536.0 79.526831 ... 81.355034 78.643199
11 1664.0 63.372618 ... 63.372618 62.492442
9 1408.0 64.138541 ... 68.147202 66.485074
10 1536.0 80.430545 ... 80.430545 78.643199
11 1664.0 62.929456 ... 63.372618 62.492442
12 1792.0 72.983276 ... 72.983276 59.154861
13 1920.0 68.776119 ... 71.626943 70.892307
14 2048.0 73.584279 ... 78.033565 76.959706
15 2176.0 83.155572 ... 87.494120 86.367588
16 2304.0 68.446623 ... 78.064941 77.057651
17 2432.0 71.305746 ... 86.179335 85.393507
18 2560.0 77.833728 ... 82.956960 81.715711
19 2688.0 83.737433 ... 91.185232 89.464755
20 2816.0 82.446516 ... 84.523664 83.712490
21 2944.0 81.967162 ... 83.758038 82.373605
22 3072.0 82.420822 ... 88.750943 86.579673
23 3200.0 81.528664 ... 91.233074 95.665176
24 3328.0 83.516586 ... 85.908470 83.323259
25 3456.0 81.435930 ... 92.138932 90.180725
26 3584.0 83.954614 ... 91.189190 95.858629
27 3712.0 85.822459 ... 83.806497 87.783251
28 3840.0 80.901241 ... 89.259080 89.548180
29 3968.0 87.913500 ... 92.829164 84.096442
30 4096.0 93.825748 ... 89.299883 90.139506
13 1920.0 69.120002 ... 71.257735 71.257735
14 2048.0 73.584279 ... 78.398206 77.314362
15 2176.0 83.155572 ... 87.494120 85.998493
16 2304.0 68.446623 ... 78.320893 77.558029
17 2432.0 71.305746 ... 86.711310 75.421383
18 2560.0 77.833728 ... 82.747477 81.715711
19 2688.0 83.552988 ... 90.532356 89.464755
20 2816.0 84.197315 ... 84.035084 84.035084
21 2944.0 82.784108 ... 83.969728 83.060049
22 3072.0 81.825298 ... 89.593522 88.473602
23 3200.0 84.768213 ... 96.096095 95.808380
24 3328.0 83.226931 ... 85.908470 84.596116
25 3456.0 81.766291 ... 91.824110 91.097818
26 3584.0 87.466332 ... 91.194972 94.847460
27 3712.0 85.822459 ... 87.246590 87.860458
28 3840.0 81.859361 ... 87.011801 90.168771
29 3968.0 89.921841 ... 91.954739 85.271796
30 4096.0 93.596744 ... 88.243079 90.382307
[31 rows x 5 columns]
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 52.578 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 7 minutes 13.827 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>

View File

@@ -108,6 +108,8 @@
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -372,7 +374,7 @@ to explore the <cite>triton/language/random</cite> folder!</p>
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
</dd>
</dl>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.476 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.279 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>

View File

@@ -46,7 +46,7 @@
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="next" title="Fused Attention" href="06-fused-attention.html" />
<link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
</head>
@@ -101,6 +101,8 @@
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -196,34 +198,34 @@ to download the full example code</p>
N Triton Torch Apex
0 1024.0 585.142849 277.694907 468.114273
1 1536.0 630.153868 323.368435 511.999982
2 2048.0 682.666643 334.367358 520.126988
3 2560.0 694.237267 365.714281 518.481028
4 3072.0 712.347810 378.092307 501.551037
5 3584.0 725.873439 384.859062 458.751978
6 4096.0 728.177767 381.023256 458.293714
7 4608.0 670.254540 396.387087 426.173427
8 5120.0 694.237267 397.669909 426.666652
9 5632.0 704.000002 396.969169 413.357796
10 6144.0 702.171410 402.885254 411.313806
2 2048.0 668.734716 337.814445 528.516136
3 2560.0 694.237267 362.477870 512.000013
4 3072.0 712.347810 375.206126 501.551037
5 3584.0 725.873439 384.859062 451.527536
6 4096.0 728.177767 381.023256 455.111095
7 4608.0 670.254540 396.387087 421.302872
8 5120.0 688.403381 395.748783 422.268057
9 5632.0 698.542675 396.969169 409.599997
10 6144.0 702.171410 402.885254 409.600010
11 6656.0 700.631610 400.360920 400.360920
12 7168.0 695.078767 396.844306 388.772874
13 7680.0 682.666656 393.846167 387.634072
14 8192.0 642.509816 393.609605 372.363633
15 8704.0 627.315309 389.005597 380.502740
16 9216.0 606.814809 407.337026 383.999986
17 9728.0 589.575753 409.599987 383.369452
18 10240.0 566.920437 408.578556 382.803739
19 10752.0 549.623009 411.559798 381.445676
20 11264.0 536.380957 406.826188 373.134567
21 11776.0 523.377770 410.492372 377.587162
22 12288.0 517.389457 414.784810 383.251457
23 12800.0 505.679014 410.420828 376.470582
24 13312.0 494.180982 405.699062 376.976995
25 13824.0 482.934503 411.888257 379.389355
26 14336.0 471.967074 406.695045 374.185964
27 14848.0 461.297068 408.192434 375.304904
28 15360.0 454.269882 406.214870 378.092307
29 15872.0 447.887117 407.627589 376.225175
12 7168.0 678.627194 386.154893 384.859062
13 7680.0 682.666656 391.337574 386.415087
14 8192.0 645.674867 390.095241 376.643677
15 8704.0 624.502255 390.095225 379.465939
16 9216.0 604.327881 405.098894 383.002605
17 9728.0 585.142883 409.599987 382.427505
18 10240.0 564.965524 409.600010 382.803739
19 10752.0 546.133312 410.577576 380.601764
20 11264.0 531.634232 395.228063 370.069806
21 11776.0 520.486200 409.599991 376.831982
22 12288.0 516.031509 413.911572 383.251457
23 12800.0 504.433489 410.420828 375.779805
24 13312.0 494.180982 405.699062 376.310952
25 13824.0 481.882350 411.888257 378.739711
26 14336.0 471.967074 401.709294 372.969090
27 14848.0 461.297068 407.492270 375.898745
28 15360.0 453.431739 406.887417 378.092307
29 15872.0 447.098578 406.323209 376.225175
</pre></div>
</div>
<div class="line-block">
@@ -354,17 +356,19 @@ to download the full example code</p>
<span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dout</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DOut</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Mean</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Var</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">a_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">a</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span> <span class="o">*</span> <span class="n">rstd</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">dw</span> <span class="o">+=</span> <span class="n">dout</span> <span class="o">*</span> <span class="n">a_hat</span>
<span class="n">db</span> <span class="o">+=</span> <span class="n">dout</span>
<span class="n">UNROLL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span> <span class="o">=</span> <span class="mi">4</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">*</span> <span class="n">UNROLL</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">UNROLL</span><span class="p">):</span>
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">j</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dout</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DOut</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Mean</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Var</span> <span class="o">+</span> <span class="n">rows</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">rows</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">a_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">a</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span> <span class="o">*</span> <span class="n">rstd</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">dw</span> <span class="o">+=</span> <span class="n">dout</span> <span class="o">*</span> <span class="n">a_hat</span>
<span class="n">db</span> <span class="o">+=</span> <span class="n">dout</span>
<span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
@@ -437,7 +441,15 @@ to download the full example code</p>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># accumulate partial sums in separate kernel</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&gt;</span> <span class="mi">10240</span><span class="p">:</span>
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># maximize occupancy for small N</span>
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s2">&quot;BLOCK_SIZE_N&quot;</span><span class="p">])]</span>
<span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="n">a</span><span class="p">,</span> <span class="n">dout</span><span class="p">,</span>
@@ -446,17 +458,11 @@ to download the full example code</p>
<span class="n">dbias</span><span class="p">,</span>
<span class="n">M</span><span class="p">,</span>
<span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">BLOCK_SIZE_N</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span>
<span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">da</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dweight</span><span class="p">,</span> <span class="n">dbias</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span>
<span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">da</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dweight</span><span class="p">,</span> <span class="n">dbias</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">layer_norm</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
@@ -537,7 +543,7 @@ to download the full example code</p>
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 24.641 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 32.552 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
@@ -555,7 +561,7 @@ to download the full example code</p>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="06-fused-attention.html" class="btn btn-neutral float-right" title="Fused Attention" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>

View File

@@ -0,0 +1,623 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Fused Attention &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Libdevice function" href="07-libdevice-function.html" />
<link rel="prev" title="Layer Normalization" href="05-layer-norm.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Fused Attention</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/06-fused-attention.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="fused-attention">
<span id="sphx-glr-getting-started-tutorials-06-fused-attention-py"></span><h1>Fused Attention<a class="headerlink" href="#fused-attention" title="Permalink to this headline"></a></h1>
<p>This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., <a class="reference external" href="https://arxiv.org/pdf/2205.14135v2.pdf">https://arxiv.org/pdf/2205.14135v2.pdf</a>; Rabe and Staats <a class="reference external" href="https://arxiv.org/pdf/2112.05682v2.pdf">https://arxiv.org/pdf/2112.05682v2.pdf</a>)</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">pytest</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_fwd_kernel</span><span class="p">(</span>
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
<span class="n">TMP</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="c1"># NOTE: TMP is a scratchpad buffer to workaround a compiler bug</span>
<span class="n">Out</span><span class="p">,</span>
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
<span class="n">stride_oz</span><span class="p">,</span> <span class="n">stride_oh</span><span class="p">,</span> <span class="n">stride_om</span><span class="p">,</span> <span class="n">stride_on</span><span class="p">,</span>
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># initialize offsets</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="n">offs_d</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="n">off_q</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
<span class="n">off_k</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span>
<span class="n">off_v</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
<span class="c1"># Initialize pointers to Q, K, V</span>
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="n">off_q</span>
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="n">off_k</span>
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="n">off_v</span>
<span class="c1"># initialize pointer to m and l</span>
<span class="n">t_ptrs</span> <span class="o">=</span> <span class="n">TMP</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">m_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">)</span>
<span class="n">l_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># load q: it will stay in SRAM throughout</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
<span class="c1"># loop over k, v and update accumulator</span>
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">(</span><span class="n">start_m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">):</span>
<span class="n">start_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">multiple_of</span><span class="p">(</span><span class="n">start_n</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="c1"># -- compute qk ----</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_kn</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">*=</span> <span class="n">sm_scale</span>
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">start_n</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">))</span>
<span class="c1"># -- compute m_ij, p, l_ij</span>
<span class="n">m_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">qk</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">-</span> <span class="n">m_ij</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="n">l_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># -- update m_i and l_i</span>
<span class="n">m_i_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">m_i</span><span class="p">,</span> <span class="n">m_ij</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_i</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
<span class="n">beta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_ij</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
<span class="n">l_i_new</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">l_i</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">l_ij</span>
<span class="c1"># -- update output accumulator --</span>
<span class="c1"># scale p</span>
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">/</span> <span class="n">l_i_new</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">p_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># scale acc</span>
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">l_i</span> <span class="o">/</span> <span class="n">l_i_new</span> <span class="o">*</span> <span class="n">alpha</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">,</span> <span class="n">acc_scale</span><span class="p">)</span>
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">)</span> <span class="c1"># BUG: have to store and immediately load</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">acc</span> <span class="o">*</span> <span class="n">acc_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># update acc</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_vk</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="c1"># update m_i and l_i</span>
<span class="n">l_i</span> <span class="o">=</span> <span class="n">l_i_new</span>
<span class="n">m_i</span> <span class="o">=</span> <span class="n">m_i_new</span>
<span class="c1"># rematerialize offsets to save registers</span>
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="c1"># write back l and m</span>
<span class="n">l_ptrs</span> <span class="o">=</span> <span class="n">L</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">l_ptrs</span><span class="p">,</span> <span class="n">l_i</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">m_ptrs</span><span class="p">,</span> <span class="n">m_i</span><span class="p">)</span>
<span class="c1"># initialize pointers to output</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="n">off_o</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_oh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_om</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_on</span>
<span class="n">out_ptrs</span> <span class="o">=</span> <span class="n">Out</span> <span class="o">+</span> <span class="n">off_o</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">out_ptrs</span><span class="p">,</span> <span class="n">acc</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_bwd_preprocess</span><span class="p">(</span>
<span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span>
<span class="n">NewDO</span><span class="p">,</span> <span class="n">Delta</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">off_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">off_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">)</span>
<span class="c1"># load</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Out</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">L</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># compute</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span> <span class="o">/</span> <span class="n">denom</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">delta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">o</span> <span class="o">*</span> <span class="n">do</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># write-back</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">NewDO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">do</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Delta</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">,</span> <span class="n">delta</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_bwd_kernel</span><span class="p">(</span>
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span> <span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span>
<span class="n">DQ</span><span class="p">,</span> <span class="n">DK</span><span class="p">,</span> <span class="n">DV</span><span class="p">,</span>
<span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span>
<span class="n">D</span><span class="p">,</span>
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
<span class="n">num_block</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">off_z</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">//</span> <span class="n">H</span>
<span class="n">off_h</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">%</span> <span class="n">H</span>
<span class="c1"># offset pointers for batch/head</span>
<span class="n">Q</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">K</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">V</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DO</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DQ</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DK</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DV</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_block</span><span class="p">):</span>
<span class="n">lo</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span>
<span class="c1"># initialize row/col offsets</span>
<span class="n">offs_qm</span> <span class="o">=</span> <span class="n">lo</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="c1"># initialize pointers to value-like data</span>
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">do_ptrs</span> <span class="o">=</span> <span class="n">DO</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">dq_ptrs</span> <span class="o">=</span> <span class="n">DQ</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="c1"># pointer to row-wise quantities in value-like data</span>
<span class="n">D_ptrs</span> <span class="o">=</span> <span class="n">D</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
<span class="c1"># initialize dv amd dk</span>
<span class="n">dv</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># k and v stay in SRAM throughout</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span><span class="p">)</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span><span class="p">)</span>
<span class="c1"># loop over rows</span>
<span class="k">for</span> <span class="n">start_m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">lo</span><span class="p">,</span> <span class="n">num_block</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">):</span>
<span class="n">offs_m_curr</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="c1"># load q, k, v, do on-chip</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
<span class="c1"># recompute p = softmax(qk, dim=-1).T</span>
<span class="c1"># NOTE: `do` is pre-divided by `l`; no normalization here</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m_curr</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="n">qk</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">))</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">m_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">*</span> <span class="n">sm_scale</span> <span class="o">-</span> <span class="n">m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="c1"># compute dv</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">do_ptrs</span><span class="p">)</span>
<span class="n">dv</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">do</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># compute dp = dot(v, do)</span>
<span class="n">Di</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">D_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
<span class="n">dp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="n">Di</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">dp</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># compute ds = p * (dp - delta[:, None])</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">dp</span> <span class="o">*</span> <span class="n">sm_scale</span>
<span class="c1"># compute dk = dot(ds.T, q)</span>
<span class="n">dk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">q</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># # compute dq</span>
<span class="n">dq</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">&quot;evict_last&quot;</span><span class="p">)</span>
<span class="n">dq</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">k</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">dq</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">&quot;evict_last&quot;</span><span class="p">)</span>
<span class="c1"># # increment pointers</span>
<span class="n">dq_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="n">q_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="n">do_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="c1"># write-back</span>
<span class="n">dv_ptrs</span> <span class="o">=</span> <span class="n">DV</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">dk_ptrs</span> <span class="o">=</span> <span class="n">DK</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dv_ptrs</span><span class="p">,</span> <span class="n">dv</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dk_ptrs</span><span class="p">,</span> <span class="n">dk</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">_attention</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">):</span>
<span class="n">BLOCK</span> <span class="o">=</span> <span class="mi">128</span>
<span class="c1"># shape constraints</span>
<span class="n">Lq</span><span class="p">,</span> <span class="n">Lk</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">Lq</span> <span class="o">==</span> <span class="n">Lk</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">BLOCK</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">tmp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">L</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">_fwd_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
<span class="n">tmp</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
<span class="n">o</span><span class="p">,</span>
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span>
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span> <span class="o">=</span> <span class="n">BLOCK</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span> <span class="o">=</span> <span class="n">sm_scale</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span> <span class="o">=</span> <span class="mi">64</span>
<span class="k">return</span> <span class="n">o</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">do</span><span class="p">):</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">dq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">dv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
<span class="n">do_scaled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">do</span><span class="p">)</span>
<span class="n">delta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">l</span><span class="p">)</span>
<span class="n">_bwd_preprocess</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">)](</span>
<span class="n">o</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span>
<span class="n">do_scaled</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">_bwd_kernel</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],)](</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span><span class="p">,</span>
<span class="n">o</span><span class="p">,</span> <span class="n">do_scaled</span><span class="p">,</span>
<span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span>
<span class="n">l</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
<span class="n">delta</span><span class="p">,</span>
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span>
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span> <span class="kc">None</span>
<span class="n">attention</span> <span class="o">=</span> <span class="n">_attention</span><span class="o">.</span><span class="n">apply</span>
<span class="nd">@pytest</span><span class="o">.</span><span class="n">mark</span><span class="o">.</span><span class="n">parametrize</span><span class="p">(</span><span class="s1">&#39;Z, H, N_CTX, D_HEAD&#39;</span><span class="p">,</span> <span class="p">[(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">64</span><span class="p">)])</span>
<span class="k">def</span> <span class="nf">test_op</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">0.3</span>
<span class="n">dout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="c1"># reference implementation</span>
<span class="n">M</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">))</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="o">*</span> <span class="n">sm_scale</span>
<span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
<span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="p">):</span>
<span class="n">p</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">M</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">half</span><span class="p">()</span>
<span class="n">ref_out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="n">ref_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
<span class="n">ref_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">ref_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">ref_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="c1"># triton implementation</span>
<span class="n">tri_out</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
<span class="n">tri_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
<span class="n">tri_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">tri_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">tri_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="c1"># compare</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_out</span><span class="p">,</span> <span class="n">tri_out</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dv</span><span class="p">,</span> <span class="n">tri_dv</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dk</span><span class="p">,</span> <span class="n">tri_dk</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dq</span><span class="p">,</span> <span class="n">tri_dq</span><span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">flash_attn.flash_attn_interface</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">except</span> <span class="ne">BaseException</span><span class="p">:</span>
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">BATCH</span><span class="p">,</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">4096</span><span class="p">,</span> <span class="mi">64</span>
<span class="c1"># vary seq length for fixed head and batch=4</span>
<span class="n">configs</span> <span class="o">=</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N_CTX&#39;</span><span class="p">],</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">16</span><span class="p">)],</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;flash&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;Flash&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;red&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;ms&#39;</span><span class="p">,</span>
<span class="n">plot_name</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;fused-attention-batch</span><span class="si">{</span><span class="n">BATCH</span><span class="si">}</span><span class="s1">-head</span><span class="si">{</span><span class="n">N_HEADS</span><span class="si">}</span><span class="s1">-d</span><span class="si">{</span><span class="n">D_HEAD</span><span class="si">}</span><span class="s1">-</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;H&#39;</span><span class="p">:</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="s1">&#39;BATCH&#39;</span><span class="p">:</span> <span class="n">BATCH</span><span class="p">,</span> <span class="s1">&#39;D_HEAD&#39;</span><span class="p">:</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">&#39;mode&#39;</span><span class="p">:</span> <span class="n">mode</span><span class="p">}</span>
<span class="p">)</span> <span class="k">for</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;bwd&#39;</span><span class="p">]]</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_flash_attention</span><span class="p">(</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;fwd&#39;</span><span class="p">,</span> <span class="s1">&#39;bwd&#39;</span><span class="p">]</span>
<span class="n">warmup</span> <span class="o">=</span> <span class="mi">25</span>
<span class="n">rep</span> <span class="o">=</span> <span class="mi">100</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">&quot;triton&quot;</span><span class="p">:</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">1.3</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bwd&#39;</span><span class="p">:</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ms</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">&quot;flash&quot;</span><span class="p">:</span>
<span class="n">lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,),</span> <span class="n">fill_value</span><span class="o">=</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="n">cu_seqlens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">cu_seqlens</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">lengths</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">*</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">cu_seqlens</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bwd&#39;</span><span class="p">:</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ms</span>
<span class="c1"># only works on A100 at the moment</span>
<span class="c1"># bench_flash_attention.run(save_path=&#39;.&#39;, print_data=True)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.072 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-06-fused-attention-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">06-fused-attention.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">06-fused-attention.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="07-libdevice-function.html" class="btn btn-neutral float-right" title="Libdevice function" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="05-layer-norm.html" class="btn btn-neutral float-left" title="Layer Normalization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="06-fused-attention.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,357 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Libdevice function &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="prev" title="Fused Attention" href="06-fused-attention.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Libdevice function</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#asin-kernel">asin Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#using-the-default-libdevice-library-path">Using the default libdevice library path</a></li>
<li class="toctree-l3"><a class="reference internal" href="#customize-the-libdevice-library-path">Customize the libdevice library path</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Libdevice function</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/07-libdevice-function.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="libdevice-function">
<span id="sphx-glr-getting-started-tutorials-07-libdevice-function-py"></span><h1>Libdevice function<a class="headerlink" href="#libdevice-function" title="Permalink to this headline"></a></h1>
<p>Triton can invoke a custom function from an external library.
In this example, we will use the <cite>libdevice</cite> library to apply <cite>asin</cite> on a tensor.
Please refer to <a class="reference external" href="https://docs.nvidia.com/cuda/libdevice-users-guide/index.html">https://docs.nvidia.com/cuda/libdevice-users-guide/index.html</a> regarding the semantics of all available libdevice functions.</p>
<p>In <cite>trition/language/libdevice.py</cite>, we try to aggregate functions with the same computation but different data types together.
For example, both <cite>__nv_asin</cite> and <cite>__nvasinf</cite> calculate the principal value of the arc sine of the input, but <cite>__nv_asin</cite> operates on <cite>double</cite> and <cite>__nv_asinf</cite> operates on <cite>float</cite>.
Using triton, you can simply call <cite>tl.libdevice.asinf</cite>.
triton automatically selects the correct underlying device function to invoke based on input and output types.</p>
<div class="section" id="asin-kernel">
<h2>asin Kernel<a class="headerlink" href="#asin-kernel" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">asin_kernel</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span>
<span class="n">y_ptr</span><span class="p">,</span>
<span class="n">n_elements</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">libdevice</span><span class="o">.</span><span class="n">asin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="using-the-default-libdevice-library-path">
<h2>Using the default libdevice library path<a class="headerlink" href="#using-the-default-libdevice-library-path" title="Permalink to this headline"></a></h2>
<p>We can use the default libdevice library path encoded in <cite>triton/language/libdevice.py</cite></p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">asin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output_triton</span><span class="o">.</span><span class="n">is_cuda</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">output_torch</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="n">asin_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output_triton</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=&#39;cuda:0&#39;)
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 2.384185791015625e-07
</pre></div>
</div>
</div>
<div class="section" id="customize-the-libdevice-library-path">
<h2>Customize the libdevice library path<a class="headerlink" href="#customize-the-libdevice-library-path" title="Permalink to this headline"></a></h2>
<p>We can also customize the libdevice library path by passing the path to the <cite>libdevice</cite> library to the <cite>asin</cite> kernel.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">output_triton</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">asin_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output_triton</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
<span class="n">extern_libs</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;libdevice&#39;</span><span class="p">:</span> <span class="s1">&#39;/usr/local/cuda/nvvm/libdevice/libdevice.10.bc&#39;</span><span class="p">})</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=&#39;cuda:0&#39;)
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 2.384185791015625e-07
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.501 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-07-libdevice-function-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">07-libdevice-function.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">07-libdevice-function.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="06-fused-attention.html" class="btn btn-neutral float-left" title="Fused Attention" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="07-libdevice-function.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -101,6 +101,8 @@
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
@@ -219,6 +221,18 @@ pip install -e <span class="s1">&#39;./python[tutorials]&#39;</span>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention"><div class="figure align-default" id="id6">
<img alt="Fused Attention" src="../../_images/sphx_glr_06-fused-attention_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">Fused Attention</span></a></span><a class="headerlink" href="#id6" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d..."><div class="figure align-default" id="id7">
<img alt="Libdevice function" src="../../_images/sphx_glr_07-libdevice-function_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="07-libdevice-function.html#sphx-glr-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">Libdevice function</span></a></span><a class="headerlink" href="#id7" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>

View File

@@ -174,7 +174,7 @@
<div class="section" id="computation-times">
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1>
<p><strong>16:10.599</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<p><strong>18:09.339</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<table class="docutils align-default">
<colgroup>
<col style="width: 85%" />
@@ -183,23 +183,31 @@
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
<td><p>05:52.578</p></td>
<td><p>07:13.827</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Layer Normalization</span></a> (<code class="docutils literal notranslate"><span class="pre">05-layer-norm.py</span></code>)</p></td>
<td><p>05:24.641</p></td>
<td><p>05:32.552</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
<td><p>03:18.076</p></td>
<td><p>03:32.089</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
<td><p>01:34.829</p></td>
<td><p>01:50.020</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
<td><p>00:00.476</p></td>
<tr class="row-odd"><td><p><a class="reference internal" href="07-libdevice-function.html#sphx-glr-getting-started-tutorials-07-libdevice-function-py"><span class="std std-ref">Libdevice function</span></a> (<code class="docutils literal notranslate"><span class="pre">07-libdevice-function.py</span></code>)</p></td>
<td><p>00:00.501</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
<td><p>00:00.279</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">Fused Attention</span></a> (<code class="docutils literal notranslate"><span class="pre">06-fused-attention.py</span></code>)</p></td>
<td><p>00:00.072</p></td>
<td><p>0.0 MB</p></td>
</tr>
</tbody>

Binary file not shown.

View File

@@ -197,7 +197,7 @@
<h1>triton.language.dot<a class="headerlink" href="#triton-language-dot" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.dot">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_tf32</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trans_a</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trans_b</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_tf32</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition"></a></dt>
<dd><p>Returns the matrix product of two blocks.</p>
<p>The two blocks must be two dimensionals and have compatible inner dimensions.</p>
<dl class="field-list simple">

View File

@@ -200,7 +200,7 @@
<h1>triton.language.store<a class="headerlink" href="#triton-language-store" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.store">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">store</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">store</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eviction_policy</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition"></a></dt>
<dd><p>Stores <code class="code docutils literal notranslate"><span class="pre">value</span></code> tensor of elements in memory, element-wise, at the memory locations specified by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
<p><code class="code docutils literal notranslate"><span class="pre">value</span></code> is implicitly broadcast to <code class="code docutils literal notranslate"><span class="pre">pointer.shape</span></code> and typecast to <code class="code docutils literal notranslate"><span class="pre">pointer.dtype.element_ty</span></code>.</p>
<dl class="field-list simple">

Some files were not shown because too many files have changed in this diff Show More