2022-07-14 07:22:19 +00:00
<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Fused Attention — Triton documentation< / title >
< link rel = "stylesheet" href = "../../_static/pygments.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
< script data-url_root = "../../" id = "documentation_options" src = "../../_static/documentation_options.js" > < / script >
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "Libdevice function" href = "07-libdevice-function.html" / >
< link rel = "prev" title = "Layer Normalization" href = "05-layer-norm.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "05-layer-norm.html" > Layer Normalization< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Fused Attention< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "07-libdevice-function.html" > Libdevice function< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../index.html" class = "icon icon-home" > < / a > » < / li >
< li > < a href = "index.html" > Tutorials< / a > » < / li >
< li > Fused Attention< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../../_sources/getting-started/tutorials/06-fused-attention.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-06-fused-attention-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "fused-attention" >
< span id = "sphx-glr-getting-started-tutorials-06-fused-attention-py" > < / span > < h1 > Fused Attention< a class = "headerlink" href = "#fused-attention" title = "Permalink to this headline" > ¶< / a > < / h1 >
< p > This is a Triton implementation of the Flash Attention algorithm
(see Dao et al., < a class = "reference external" href = "https://arxiv.org/pdf/2205.14135v2.pdf" > https://arxiv.org/pdf/2205.14135v2.pdf< / a > ; Rabe and Staats, < a class = "reference external" href = "https://arxiv.org/pdf/2112.05682v2.pdf" > https://arxiv.org/pdf/2112.05682v2.pdf< / a > )< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > pytest< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _fwd_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > Q< / span > < span class = "p" > ,< / span > < span class = "n" > K< / span > < span class = "p" > ,< / span > < span class = "n" > V< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > ,< / span >
< span class = "n" > TMP< / span > < span class = "p" > ,< / span > < span class = "n" > L< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "c1" > # NOTE: TMP is a scratchpad buffer to work around a compiler bug< / span >
< span class = "n" > Out< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_qz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qm< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qk< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_kz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kn< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kk< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_vz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vk< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vn< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_oz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_oh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_om< / span > < span class = "p" > ,< / span > < span class = "n" > stride_on< / span > < span class = "p" > ,< / span >
< span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_M< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > start_m< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > off_hz< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "c1" > # initialize offsets< / span >
< span class = "n" > offs_m< / span > < span class = "o" > =< / span > < span class = "n" > start_m< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > )< / span >
< span class = "n" > offs_n< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > )< / span >
< span class = "n" > offs_d< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > )< / span >
< span class = "n" > off_q< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_d< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span >
< span class = "n" > off_k< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span > < span class = "o" > +< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kn< / span > < span class = "o" > +< / span > < span class = "n" > offs_d< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kk< / span >
< span class = "n" > off_v< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span > < span class = "o" > +< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_d< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span >
< span class = "c1" > # Initialize pointers to Q, K, V< / span >
< span class = "n" > q_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > Q< / span > < span class = "o" > +< / span > < span class = "n" > off_q< / span >
< span class = "n" > k_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > K< / span > < span class = "o" > +< / span > < span class = "n" > off_k< / span >
< span class = "n" > v_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > V< / span > < span class = "o" > +< / span > < span class = "n" > off_v< / span >
< span class = "c1" > # initialize pointer to the TMP scratchpad and the running m, l and acc accumulators< / span >
< span class = "n" > t_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > TMP< / span > < span class = "o" > +< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span >
< span class = "n" > m_i< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span > < span class = "o" > -< / span > < span class = "nb" > float< / span > < span class = "p" > (< / span > < span class = "s2" > "inf"< / span > < span class = "p" > )< / span >
< span class = "n" > l_i< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > acc< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "c1" > # load q: it will stay in SRAM throughout< / span >
< span class = "n" > q< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > q_ptrs< / span > < span class = "p" > )< / span >
< span class = "c1" > # loop over k, v and update accumulator< / span >
< span class = "k" > for< / span > < span class = "n" > start_n< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "p" > (< / span > < span class = "n" > start_m< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > ):< / span >
< span class = "n" > start_n< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > multiple_of< / span > < span class = "p" > (< / span > < span class = "n" > start_n< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > )< / span >
< span class = "c1" > # -- compute qk ----< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > k_ptrs< / span > < span class = "o" > +< / span > < span class = "n" > start_n< / span > < span class = "o" > *< / span > < span class = "n" > stride_kn< / span > < span class = "p" > )< / span >
< span class = "n" > qk< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > qk< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > trans_b< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > qk< / span > < span class = "o" > *=< / span > < span class = "n" > sm_scale< / span >
< span class = "n" > qk< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > offs_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > &gt;=< / span > < span class = "p" > (< / span > < span class = "n" > start_n< / span > < span class = "o" > +< / span > < span class = "n" > offs_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]),< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "nb" > float< / span > < span class = "p" > (< / span > < span class = "s2" > "-inf"< / span > < span class = "p" > ))< / span >
< span class = "c1" > # -- compute m_ij, p, l_ij< / span >
< span class = "n" > m_ij< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > qk< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > qk< / span > < span class = "o" > -< / span > < span class = "n" > m_ij< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ])< / span >
< span class = "n" > l_ij< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "c1" > # -- update m_i and l_i< / span >
< span class = "n" > m_i_new< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > maximum< / span > < span class = "p" > (< / span > < span class = "n" > m_i< / span > < span class = "p" > ,< / span > < span class = "n" > m_ij< / span > < span class = "p" > )< / span >
< span class = "n" > alpha< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > m_i< / span > < span class = "o" > -< / span > < span class = "n" > m_i_new< / span > < span class = "p" > )< / span >
< span class = "n" > beta< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > m_ij< / span > < span class = "o" > -< / span > < span class = "n" > m_i_new< / span > < span class = "p" > )< / span >
< span class = "n" > l_i_new< / span > < span class = "o" > =< / span > < span class = "n" > alpha< / span > < span class = "o" > *< / span > < span class = "n" > l_i< / span > < span class = "o" > +< / span > < span class = "n" > beta< / span > < span class = "o" > *< / span > < span class = "n" > l_ij< / span >
< span class = "c1" > # -- update output accumulator --< / span >
< span class = "c1" > # scale p< / span >
< span class = "n" > p_scale< / span > < span class = "o" > =< / span > < span class = "n" > beta< / span > < span class = "o" > /< / span > < span class = "n" > l_i_new< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > p< / span > < span class = "o" > *< / span > < span class = "n" > p_scale< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "c1" > # scale acc< / span >
< span class = "n" > acc_scale< / span > < span class = "o" > =< / span > < span class = "n" > l_i< / span > < span class = "o" > /< / span > < span class = "n" > l_i_new< / span > < span class = "o" > *< / span > < span class = "n" > alpha< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > t_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > acc_scale< / span > < span class = "p" > )< / span >
< span class = "n" > acc_scale< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > t_ptrs< / span > < span class = "p" > )< / span > < span class = "c1" > # BUG: have to store and immediately load< / span >
< span class = "n" > acc< / span > < span class = "o" > =< / span > < span class = "n" > acc< / span > < span class = "o" > *< / span > < span class = "n" > acc_scale< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "c1" > # update acc< / span >
< span class = "n" > v< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > v_ptrs< / span > < span class = "o" > +< / span > < span class = "n" > start_n< / span > < span class = "o" > *< / span > < span class = "n" > stride_vk< / span > < span class = "p" > )< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > )< / span >
< span class = "n" > acc< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > )< / span >
< span class = "c1" > # update m_i and l_i< / span >
< span class = "n" > l_i< / span > < span class = "o" > =< / span > < span class = "n" > l_i_new< / span >
< span class = "n" > m_i< / span > < span class = "o" > =< / span > < span class = "n" > m_i_new< / span >
< span class = "c1" > # rematerialize offsets to save registers< / span >
< span class = "n" > start_m< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > offs_m< / span > < span class = "o" > =< / span > < span class = "n" > start_m< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > )< / span >
< span class = "c1" > # write back l and m< / span >
< span class = "n" > l_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > L< / span > < span class = "o" > +< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span >
< span class = "n" > m_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > M< / span > < span class = "o" > +< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > l_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > l_i< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > m_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > m_i< / span > < span class = "p" > )< / span >
< span class = "c1" > # initialize pointers to output< / span >
< span class = "n" > offs_n< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > )< / span >
< span class = "n" > off_o< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > stride_oh< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_om< / span > < span class = "o" > +< / span > < span class = "n" > offs_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_on< / span >
< span class = "n" > out_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > Out< / span > < span class = "o" > +< / span > < span class = "n" > off_o< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > out_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > acc< / span > < span class = "p" > )< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _bwd_preprocess< / span > < span class = "p" > (< / span >
< span class = "n" > Out< / span > < span class = "p" > ,< / span > < span class = "n" > DO< / span > < span class = "p" > ,< / span > < span class = "n" > L< / span > < span class = "p" > ,< / span >
< span class = "n" > NewDO< / span > < span class = "p" > ,< / span > < span class = "n" > Delta< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_M< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > off_m< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > )< / span >
< span class = "n" > off_n< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > )< / span >
< span class = "c1" > # load< / span >
< span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Out< / span > < span class = "o" > +< / span > < span class = "n" > off_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > D_HEAD< / span > < span class = "o" > +< / span > < span class = "n" > off_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :])< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DO< / span > < span class = "o" > +< / span > < span class = "n" > off_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > D_HEAD< / span > < span class = "o" > +< / span > < span class = "n" > off_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :])< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > denom< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > L< / span > < span class = "o" > +< / span > < span class = "n" > off_m< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > do< / span > < span class = "o" > /< / span > < span class = "n" > denom< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > delta< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > o< / span > < span class = "o" > *< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "c1" > # write-back< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > NewDO< / span > < span class = "o" > +< / span > < span class = "n" > off_m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > D_HEAD< / span > < span class = "o" > +< / span > < span class = "n" > off_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :],< / span > < span class = "n" > do< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Delta< / span > < span class = "o" > +< / span > < span class = "n" > off_m< / span > < span class = "p" > ,< / span > < span class = "n" > delta< / span > < span class = "p" > )< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _bwd_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > Q< / span > < span class = "p" > ,< / span > < span class = "n" > K< / span > < span class = "p" > ,< / span > < span class = "n" > V< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > ,< / span > < span class = "n" > Out< / span > < span class = "p" > ,< / span > < span class = "n" > DO< / span > < span class = "p" > ,< / span >
< span class = "n" > DQ< / span > < span class = "p" > ,< / span > < span class = "n" > DK< / span > < span class = "p" > ,< / span > < span class = "n" > DV< / span > < span class = "p" > ,< / span >
< span class = "n" > L< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span >
< span class = "n" > D< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_qz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qm< / span > < span class = "p" > ,< / span > < span class = "n" > stride_qk< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_kz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kn< / span > < span class = "p" > ,< / span > < span class = "n" > stride_kk< / span > < span class = "p" > ,< / span >
< span class = "n" > stride_vz< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vh< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vk< / span > < span class = "p" > ,< / span > < span class = "n" > stride_vn< / span > < span class = "p" > ,< / span >
< span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span >
< span class = "n" > num_block< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_M< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > off_hz< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > off_z< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > //< / span > < span class = "n" > H< / span >
< span class = "n" > off_h< / span > < span class = "o" > =< / span > < span class = "n" > off_hz< / span > < span class = "o" > %< / span > < span class = "n" > H< / span >
< span class = "c1" > # offset pointers for batch/head< / span >
< span class = "n" > Q< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > K< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > V< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > DO< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > DQ< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > DK< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "n" > DV< / span > < span class = "o" > +=< / span > < span class = "n" > off_z< / span > < span class = "o" > *< / span > < span class = "n" > stride_qz< / span > < span class = "o" > +< / span > < span class = "n" > off_h< / span > < span class = "o" > *< / span > < span class = "n" > stride_qh< / span >
< span class = "k" > for< / span > < span class = "n" > start_n< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > num_block< / span > < span class = "p" > ):< / span >
< span class = "n" > lo< / span > < span class = "o" > =< / span > < span class = "n" > start_n< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span >
< span class = "c1" > # initialize row/col offsets< / span >
< span class = "n" > offs_qm< / span > < span class = "o" > =< / span > < span class = "n" > lo< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > )< / span >
< span class = "n" > offs_n< / span > < span class = "o" > =< / span > < span class = "n" > start_n< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > )< / span >
< span class = "n" > offs_m< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > )< / span >
< span class = "n" > offs_k< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > )< / span >
< span class = "c1" > # initialize pointers to value-like data< / span >
< span class = "n" > q_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > Q< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_qm< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span > < span class = "p" > )< / span >
< span class = "n" > k_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > K< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kn< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kk< / span > < span class = "p" > )< / span >
< span class = "n" > v_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > V< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span > < span class = "p" > )< / span >
< span class = "n" > do_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > DO< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_qm< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span > < span class = "p" > )< / span >
< span class = "n" > dq_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > DQ< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_qm< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span > < span class = "p" > )< / span >
< span class = "c1" > # pointer to row-wise quantities in value-like data< / span >
< span class = "n" > D_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > D< / span > < span class = "o" > +< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span >
< span class = "n" > m_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > M< / span > < span class = "o" > +< / span > < span class = "n" > off_hz< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span >
< span class = "c1" > # initialize dv and dk< / span >
< span class = "n" > dv< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dk< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "c1" > # k and v stay in SRAM throughout< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > k_ptrs< / span > < span class = "p" > )< / span >
< span class = "n" > v< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > v_ptrs< / span > < span class = "p" > )< / span >
< span class = "c1" > # loop over rows< / span >
< span class = "k" > for< / span > < span class = "n" > start_m< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > lo< / span > < span class = "p" > ,< / span > < span class = "n" > num_block< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ):< / span >
< span class = "n" > offs_m_curr< / span > < span class = "o" > =< / span > < span class = "n" > start_m< / span > < span class = "o" > +< / span > < span class = "n" > offs_m< / span >
< span class = "c1" > # load q, k, v, do on-chip< / span >
< span class = "n" > q< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > q_ptrs< / span > < span class = "p" > )< / span >
< span class = "c1" > # recompute p = softmax(qk, dim=-1).T< / span >
< span class = "c1" > # NOTE: `do` is pre-divided by `l`; no normalization here< / span >
< span class = "n" > qk< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > trans_b< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > qk< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > offs_m_curr< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > >=< / span > < span class = "p" > (< / span > < span class = "n" > offs_n< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]),< / span > < span class = "n" > qk< / span > < span class = "p" > ,< / span > < span class = "nb" > float< / span > < span class = "p" > (< / span > < span class = "s2" > "-inf"< / span > < span class = "p" > ))< / span >
< span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > m_ptrs< / span > < span class = "o" > +< / span > < span class = "n" > offs_m_curr< / span > < span class = "p" > )< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > qk< / span > < span class = "o" > *< / span > < span class = "n" > sm_scale< / span > < span class = "o" > -< / span > < span class = "n" > m< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ])< / span >
< span class = "c1" > # compute dv< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > do_ptrs< / span > < span class = "p" > )< / span >
< span class = "n" > dv< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ),< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > trans_a< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute dp = dot(do, v.T) - Di[:, None]< / span >
< span class = "n" > Di< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > D_ptrs< / span > < span class = "o" > +< / span > < span class = "n" > offs_m_curr< / span > < span class = "p" > )< / span >
< span class = "n" > dp< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span > < span class = "o" > -< / span > < span class = "n" > Di< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > dp< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > trans_b< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute ds = p * dp * sm_scale (Di was already subtracted from dp)< / span >
< span class = "n" > ds< / span > < span class = "o" > =< / span > < span class = "n" > p< / span > < span class = "o" > *< / span > < span class = "n" > dp< / span > < span class = "o" > *< / span > < span class = "n" > sm_scale< / span >
< span class = "c1" > # compute dk = dot(ds.T, q)< / span >
< span class = "n" > dk< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > ds< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > trans_a< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # # compute dq< / span >
< span class = "n" > dq< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > dq_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > "evict_last"< / span > < span class = "p" > )< / span >
< span class = "n" > dq< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > dot< / span > < span class = "p" > (< / span > < span class = "n" > ds< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > dq_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > dq< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > "evict_last"< / span > < span class = "p" > )< / span >
< span class = "c1" > # # increment pointers< / span >
< span class = "n" > dq_ptrs< / span > < span class = "o" > +=< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span >
< span class = "n" > q_ptrs< / span > < span class = "o" > +=< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span >
< span class = "n" > do_ptrs< / span > < span class = "o" > +=< / span > < span class = "n" > BLOCK_M< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span >
< span class = "c1" > # write-back< / span >
< span class = "n" > dv_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > DV< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qm< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_qk< / span > < span class = "p" > )< / span >
< span class = "n" > dk_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > DK< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "n" > offs_n< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kn< / span > < span class = "o" > +< / span > < span class = "n" > offs_k< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > *< / span > < span class = "n" > stride_kk< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > dv_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > dv< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > dk_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > dk< / span > < span class = "p" > )< / span >
< span class = "k" > class< / span > < span class = "nc" > _attention< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > autograd< / span > < span class = "o" > .< / span > < span class = "n" > Function< / span > < span class = "p" > ):< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > forward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > ):< / span >
< span class = "n" > BLOCK< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span >
< span class = "c1" > # shape constraints< / span >
< span class = "n" > Lq< / span > < span class = "p" > ,< / span > < span class = "n" > Lk< / span > < span class = "p" > ,< / span > < span class = "n" > Lv< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span >
< span class = "k" > assert< / span > < span class = "n" > Lq< / span > < span class = "o" > ==< / span > < span class = "n" > Lk< / span > < span class = "ow" > and< / span > < span class = "n" > Lk< / span > < span class = "o" > ==< / span > < span class = "n" > Lv< / span >
< span class = "k" > assert< / span > < span class = "n" > Lk< / span > < span class = "ow" > in< / span > < span class = "p" > {< / span > < span class = "mi" > 16< / span > < span class = "p" > ,< / span > < span class = "mi" > 32< / span > < span class = "p" > ,< / span > < span class = "mi" > 64< / span > < span class = "p" > ,< / span > < span class = "mi" > 128< / span > < span class = "p" > }< / span >
< span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > )< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ],< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > tmp< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ]),< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > L< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ]),< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ]),< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
2022-08-17 00:49:36 +00:00
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span > < span class = "k" > if< / span > < span class = "n" > Lk< / span > < span class = "o" > < =< / span > < span class = "mi" > 64< / span > < span class = "k" > else< / span > < span class = "mi" > 8< / span >
2022-07-14 07:22:19 +00:00
< span class = "n" > _fwd_kernel< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span >
< span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > ,< / span >
< span class = "n" > tmp< / span > < span class = "p" > ,< / span > < span class = "n" > L< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span >
< span class = "n" > o< / span > < span class = "p" > ,< / span >
< span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ],< / span >
< span class = "n" > BLOCK_M< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ,< / span >
2022-08-17 00:49:36 +00:00
< span class = "n" > BLOCK_DMODEL< / span > < span class = "o" > =< / span > < span class = "n" > Lk< / span > < span class = "p" > ,< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
2022-07-14 07:22:19 +00:00
< span class = "n" > num_stages< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > save_for_backward< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > o< / span > < span class = "p" > ,< / span > < span class = "n" > L< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "n" > grid< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > sm_scale< / span > < span class = "o" > =< / span > < span class = "n" > sm_scale< / span >
2022-08-17 00:49:36 +00:00
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "o" > =< / span > < span class = "n" > Lk< / span >
2022-07-14 07:22:19 +00:00
< span class = "k" > return< / span > < span class = "n" > o< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > backward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > do< / span > < span class = "p" > ):< / span >
< span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > o< / span > < span class = "p" > ,< / span > < span class = "n" > l< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > saved_tensors< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > do< / span > < span class = "o" > .< / span > < span class = "n" > contiguous< / span > < span class = "p" > ()< / span >
< span class = "n" > dq< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > zeros_like< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dk< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > )< / span >
< span class = "n" > dv< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > v< / span > < span class = "p" > )< / span >
< span class = "n" > do_scaled< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > do< / span > < span class = "p" > )< / span >
< span class = "n" > delta< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > l< / span > < span class = "p" > )< / span >
< span class = "n" > _bwd_preprocess< / span > < span class = "p" > [(< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )](< / span >
< span class = "n" > o< / span > < span class = "p" > ,< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > l< / span > < span class = "p" > ,< / span >
< span class = "n" > do_scaled< / span > < span class = "p" > ,< / span > < span class = "n" > delta< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_M< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
2022-08-17 00:49:36 +00:00
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span > < span class = "k" > if< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "o" > < =< / span > < span class = "mi" > 64< / span > < span class = "k" > else< / span > < span class = "mi" > 8< / span >
2022-07-14 07:22:19 +00:00
< span class = "n" > _bwd_kernel< / span > < span class = "p" > [(< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],)](< / span >
< span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > sm_scale< / span > < span class = "p" > ,< / span >
< span class = "n" > o< / span > < span class = "p" > ,< / span > < span class = "n" > do_scaled< / span > < span class = "p" > ,< / span >
< span class = "n" > dq< / span > < span class = "p" > ,< / span > < span class = "n" > dk< / span > < span class = "p" > ,< / span > < span class = "n" > dv< / span > < span class = "p" > ,< / span >
< span class = "n" > l< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span >
< span class = "n" > delta< / span > < span class = "p" > ,< / span >
< span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ),< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > ),< / span >
< span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ],< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grid< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],< / span >
< span class = "n" > BLOCK_M< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_N< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ,< / span >
2022-08-17 00:49:36 +00:00
< span class = "n" > BLOCK_DMODEL< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_DMODEL< / span > < span class = "p" > ,< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
2022-07-14 07:22:19 +00:00
< span class = "n" > num_stages< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > dq< / span > < span class = "p" > ,< / span > < span class = "n" > dk< / span > < span class = "p" > ,< / span > < span class = "n" > dv< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span >
< span class = "n" > attention< / span > < span class = "o" > =< / span > < span class = "n" > _attention< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span >
< span class = "nd" > @pytest< / span > < span class = "o" > .< / span > < span class = "n" > mark< / span > < span class = "o" > .< / span > < span class = "n" > parametrize< / span > < span class = "p" > (< / span > < span class = "s1" > ' Z, H, N_CTX, D_HEAD' < / span > < span class = "p" > ,< / span > < span class = "p" > [(< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 2048< / span > < span class = "p" > ,< / span > < span class = "mi" > 64< / span > < span class = "p" > )])< / span >
< span class = "k" > def< / span > < span class = "nf" > test_op< / span > < span class = "p" > (< / span > < span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ):< / span >
< span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 20< / span > < span class = "p" > )< / span >
< span class = "n" > q< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > normal_< / span > < span class = "p" > (< / span > < span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > std< / span > < span class = "o" > =< / span > < span class = "mf" > .5< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > ()< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > normal_< / span > < span class = "p" > (< / span > < span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > std< / span > < span class = "o" > =< / span > < span class = "mf" > .5< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > ()< / span >
< span class = "n" > v< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > Z< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > normal_< / span > < span class = "p" > (< / span > < span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > std< / span > < span class = "o" > =< / span > < span class = "mf" > .5< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > ()< / span >
< span class = "n" > sm_scale< / span > < span class = "o" > =< / span > < span class = "mf" > 0.3< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > )< / span >
< span class = "c1" > # reference implementation< / span >
< span class = "n" > M< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > tril< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > ones< / span > < span class = "p" > ((< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ),< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > ))< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > matmul< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > transpose< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ))< / span > < span class = "o" > *< / span > < span class = "n" > sm_scale< / span >
< span class = "k" > for< / span > < span class = "n" > z< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > Z< / span > < span class = "p" > ):< / span >
< span class = "k" > for< / span > < span class = "n" > h< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > H< / span > < span class = "p" > ):< / span >
< span class = "n" > p< / span > < span class = "p" > [:,< / span > < span class = "p" > :,< / span > < span class = "n" > M< / span > < span class = "o" > ==< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "nb" > float< / span > < span class = "p" > (< / span > < span class = "s2" > " -inf" < / span > < span class = "p" > )< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > float< / span > < span class = "p" > (),< / span > < span class = "n" > dim< / span > < span class = "o" > =-< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > half< / span > < span class = "p" > ()< / span >
< span class = "n" > ref_out< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > matmul< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > )< / span >
< span class = "n" > ref_out< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dout< / span > < span class = "p" > )< / span >
< span class = "n" > ref_dv< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "n" > ref_dk< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "n" > ref_dq< / span > < span class = "p" > ,< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "c1" > # triton implementation< / span >
< span class = "n" > tri_out< / span > < span class = "o" > =< / span > < span class = "n" > attention< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > )< / span >
< span class = "n" > tri_out< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dout< / span > < span class = "p" > )< / span >
< span class = "n" > tri_dv< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "n" > tri_dk< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "n" > tri_dq< / span > < span class = "p" > ,< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "n" > q< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > (),< / span > < span class = "kc" > None< / span >
< span class = "c1" > # compare< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > ref_out< / span > < span class = "p" > ,< / span > < span class = "n" > tri_out< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > ref_dv< / span > < span class = "p" > ,< / span > < span class = "n" > tri_dv< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > ref_dk< / span > < span class = "p" > ,< / span > < span class = "n" > tri_dk< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > ref_dq< / span > < span class = "p" > ,< / span > < span class = "n" > tri_dq< / span > < span class = "p" > )< / span >
< span class = "k" > try< / span > < span class = "p" > :< / span >
< span class = "kn" > from< / span > < span class = "nn" > flash_attn.flash_attn_interface< / span > < span class = "kn" > import< / span > < span class = "n" > flash_attn_func< / span >
< span class = "n" > HAS_FLASH< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "k" > except< / span > < span class = "ne" > BaseException< / span > < span class = "p" > :< / span >
< span class = "n" > HAS_FLASH< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span >
< span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "n" > N_HEADS< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span > < span class = "mi" > 48< / span > < span class = "p" > ,< / span > < span class = "mi" > 4096< / span > < span class = "p" > ,< / span > < span class = "mi" > 64< / span >
< span class = "c1" > # vary seq length for fixed head and batch=4< / span >
< span class = "n" > configs< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' N_CTX' < / span > < span class = "p" > ],< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "o" > **< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 10< / span > < span class = "p" > ,< / span > < span class = "mi" > 16< / span > < span class = "p" > )],< / span >
< span class = "n" > line_arg< / span > < span class = "o" > =< / span > < span class = "s1" > ' provider' < / span > < span class = "p" > ,< / span >
< span class = "n" > line_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > ' flash' < / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_FLASH< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > line_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' Triton' < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > ' Flash' < / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_FLASH< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > styles< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s1" > ' red' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' blue' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > )],< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s1" > ' ms' < / span > < span class = "p" > ,< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "sa" > f< / span > < span class = "s1" > ' fused-attention-batch< / span > < span class = "si" > {< / span > < span class = "n" > BATCH< / span > < span class = "si" > }< / span > < span class = "s1" > -head< / span > < span class = "si" > {< / span > < span class = "n" > N_HEADS< / span > < span class = "si" > }< / span > < span class = "s1" > -d< / span > < span class = "si" > {< / span > < span class = "n" > D_HEAD< / span > < span class = "si" > }< / span > < span class = "s1" > -< / span > < span class = "si" > {< / span > < span class = "n" > mode< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > ,< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' H' < / span > < span class = "p" > :< / span > < span class = "n" > N_HEADS< / span > < span class = "p" > ,< / span > < span class = "s1" > ' BATCH' < / span > < span class = "p" > :< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "s1" > ' D_HEAD' < / span > < span class = "p" > :< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ,< / span > < span class = "s1" > ' dtype' < / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ,< / span > < span class = "s1" > ' mode' < / span > < span class = "p" > :< / span > < span class = "n" > mode< / span > < span class = "p" > }< / span >
< span class = "p" > )< / span > < span class = "k" > for< / span > < span class = "n" > mode< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "s1" > ' bwd' < / span > < span class = "p" > ]]< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span > < span class = "n" > configs< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > bench_flash_attention< / span > < span class = "p" > (< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ,< / span > < span class = "n" > mode< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > ):< / span >
< span class = "k" > assert< / span > < span class = "n" > mode< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "s1" > ' fwd' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' bwd' < / span > < span class = "p" > ]< / span >
< span class = "n" > warmup< / span > < span class = "o" > =< / span > < span class = "mi" > 25< / span >
< span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 100< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s2" > " triton" < / span > < span class = "p" > :< / span >
< span class = "n" > q< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > v< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > " cuda" < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > sm_scale< / span > < span class = "o" > =< / span > < span class = "mf" > 1.3< / span >
< span class = "n" > fn< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > attention< / span > < span class = "p" > (< / span > < span class = "n" > q< / span > < span class = "p" > ,< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > sm_scale< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' bwd' < / span > < span class = "p" > :< / span >
< span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > fn< / span > < span class = "p" > ()< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > o< / span > < span class = "p" > )< / span >
< span class = "n" > fn< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "n" > fn< / span > < span class = "p" > ,< / span > < span class = "n" > percentiles< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > warmup< / span > < span class = "o" > =< / span > < span class = "n" > warmup< / span > < span class = "p" > ,< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "n" > rep< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > ms< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s2" > " flash" < / span > < span class = "p" > :< / span >
< span class = "n" > lengths< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > full< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "p" > ,),< / span > < span class = "n" > fill_value< / span > < span class = "o" > =< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > cu_seqlens< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > ,),< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span >
< span class = "n" > cu_seqlens< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > :]< / span > < span class = "o" > =< / span > < span class = "n" > lengths< / span > < span class = "o" > .< / span > < span class = "n" > cumsum< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > qkv< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > ((< / span > < span class = "n" > BATCH< / span > < span class = "o" > *< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "n" > H< / span > < span class = "p" > ,< / span > < span class = "n" > D_HEAD< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > fn< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > flash_attn_func< / span > < span class = "p" > (< / span > < span class = "n" > qkv< / span > < span class = "p" > ,< / span > < span class = "n" > cu_seqlens< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > N_CTX< / span > < span class = "p" > ,< / span > < span class = "n" > causal< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' bwd' < / span > < span class = "p" > :< / span >
< span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > fn< / span > < span class = "p" > ()< / span >
< span class = "n" > do< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > o< / span > < span class = "p" > )< / span >
< span class = "n" > fn< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > do< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "n" > fn< / span > < span class = "p" > ,< / span > < span class = "n" > percentiles< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > warmup< / span > < span class = "o" > =< / span > < span class = "n" > warmup< / span > < span class = "p" > ,< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "n" > rep< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > ms< / span >
< span class = "c1" > # only works on A100 at the moment< / span >
< span class = "c1" > # bench_flash_attention.run(save_path=' .' , print_data=True)< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 0 minutes 0.076 seconds)< / p >
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-06-fused-attention-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 06-fused-attention.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 06-fused-attention.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "07-libdevice-function.html" class = "btn btn-neutral float-right" title = "Libdevice function" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "05-layer-norm.html" class = "btn btn-neutral float-left" title = "Layer Normalization" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "06-fused-attention.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
// Register a DOM-ready callback that switches on the Read the Docs theme's
// navigation behaviour. `SphinxRtdTheme` is expected to be defined by the
// theme script loaded in the page head.
jQuery(function enableThemeNavigation() {
  // NOTE(review): the meaning of the `true` argument is defined by the theme
  // script, not by this page — confirm before changing it.
  SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >