2022-06-05 21:05:02 +00:00
<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Layer Normalization — Triton documentation< / title >
< link rel = "stylesheet" href = "../../_static/pygments.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!--[if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
< script data-url_root = "../../" id = "documentation_options" src = "../../_static/documentation_options.js" > < / script >
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "triton" href = "../../python-api/triton.html" / >
< link rel = "prev" title = "Low-Memory Dropout" href = "04-low-memory-dropout.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" aria-label = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Layer Normalization< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../index.html" class = "icon icon-home" aria-label = "Home" > < / a > » < / li >
< li > < a href = "index.html" > Tutorials< / a > » < / li >
< li > Layer Normalization< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../../_sources/getting-started/tutorials/05-layer-norm.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-05-layer-norm-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "layer-normalization" >
< span id = "sphx-glr-getting-started-tutorials-05-layer-norm-py" > < / span > < h1 > Layer Normalization< a class = "headerlink" href = "#layer-normalization" title = "Permalink to this headline" > ¶< / a > < / h1 >
< img alt = "Layer-norm benchmark plot comparing Triton, Torch and Apex across values of N" class = "sphx-glr-single-img" src = "../../_images/sphx_glr_05-layer-norm_001.png" / >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > layer-norm:
N Triton Torch Apex
2022-06-18 00:47:57 +00:00
0 1024.0 585.142849 277.694907 481.882344
2022-06-05 21:05:02 +00:00
1 1536.0 630.153868 323.368435 511.999982
2022-06-18 00:47:57 +00:00
2 2048.0 682.666643 337.814445 520.126988
3 2560.0 694.237267 362.477870 512.000013
2022-06-16 00:46:38 +00:00
4 3072.0 712.347810 378.092307 501.551037
2022-06-18 00:47:57 +00:00
5 3584.0 725.873439 384.859062 451.527536
6 4096.0 728.177767 381.023256 451.972420
7 4608.0 676.403666 396.387087 428.651163
8 5120.0 688.403381 395.748783 420.102563
9 5632.0 709.543270 395.228063 415.262685
2022-06-13 00:48:38 +00:00
10 6144.0 702.171410 402.885254 411.313806
2022-06-18 00:47:57 +00:00
11 6656.0 700.631610 400.360920 400.360920
12 7168.0 690.891575 388.772874 384.859062
13 7680.0 682.666656 392.587863 386.415087
14 8192.0 639.375598 390.095241 370.259899
15 8704.0 624.502255 389.005597 379.465939
16 9216.0 606.814809 406.214877 382.010363
17 9728.0 587.350922 408.524944 382.427505
18 10240.0 566.920437 409.600010 382.803739
2022-06-13 00:48:38 +00:00
19 10752.0 549.623009 411.559798 381.445676
2022-06-18 00:47:57 +00:00
20 11264.0 534.789310 403.185684 371.595879
21 11776.0 523.377770 410.492372 376.831982
22 12288.0 518.754611 413.911572 383.251457
23 12800.0 505.679014 409.599981 377.163903
24 13312.0 495.330249 405.699062 376.976995
25 13824.0 482.934503 412.656711 379.389355
26 14336.0 471.967074 403.830973 371.158581
27 14848.0 461.297068 406.794504 374.712936
28 15360.0 454.269882 406.887417 378.092307
29 15872.0 447.098578 406.974373 375.668625
2022-06-05 21:05:02 +00:00
< / pre > < / div >
< / div >
< div class = "line-block" >
< div class = "line" > < br / > < / div >
< / div >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "k" > try< / span > < span class = "p" > :< / span >
< span class = "c1" > # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it< / span >
< span class = "c1" > # should not be added to extras_require in setup.py.< / span >
< span class = "kn" > import< / span > < span class = "nn" > apex< / span >
< span class = "n" > HAS_APEX< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "k" > except< / span > < span class = "ne" > ModuleNotFoundError< / span > < span class = "p" > :< / span >
< span class = "n" > HAS_APEX< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_fwd_fused< / span > < span class = "p" > (< / span >
< span class = "n" > Out< / span > < span class = "p" > ,< / span >
< span class = "n" > A< / span > < span class = "p" > ,< / span >
< span class = "n" > Weight< / span > < span class = "p" > ,< / span >
< span class = "n" > Bias< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # position of elements processed by this program< / span >
< span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > Out< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > A< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "c1" > # compute mean< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span >
< span class = "n" > _mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > " evict_last" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > _mean< / span > < span class = "o" > +=< / span > < span class = "n" > a< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
< span class = "c1" > # compute variance< / span >
< span class = "n" > _var< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > " evict_last" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > _var< / span > < span class = "o" > +=< / span > < span class = "n" > a< / span > < span class = "o" > *< / span > < span class = "n" > a< / span >
< span class = "n" > var< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _var< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "o" > /< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sqrt< / span > < span class = "p" > (< / span > < span class = "n" > var< / span > < span class = "o" > +< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "c1" > # write-back mean/rstd< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Rstd< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > )< / span >
< span class = "c1" > # multiply by weight and add bias< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Bias< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > " evict_first" < / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > out< / span > < span class = "o" > =< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > weight< / span > < span class = "o" > +< / span > < span class = "n" > bias< / span >
< span class = "c1" > # write-back< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Out< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # Backward pass (DA + partial DW + partial DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > (< / span >
< span class = "n" > _DA< / span > < span class = "p" > ,< / span >
< span class = "n" > _DOut< / span > < span class = "p" > ,< / span >
< span class = "n" > _A< / span > < span class = "p" > ,< / span >
< span class = "n" > Weight< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > NumRows< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # position of elements processed by this program< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span >
< span class = "n" > A< / span > < span class = "o" > =< / span > < span class = "n" > _A< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > DOut< / span > < span class = "o" > =< / span > < span class = "n" > _DOut< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > DA< / span > < span class = "o" > =< / span > < span class = "n" > _DA< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Rstd< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
< span class = "c1" > # load data to SRAM< / span >
< span class = "n" > _mean1< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > _mean2< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > NumCols< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > wdout< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > *< / span > < span class = "n" > dout< / span >
< span class = "n" > _mean1< / span > < span class = "o" > +=< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > wdout< / span >
< span class = "n" > _mean2< / span > < span class = "o" > +=< / span > < span class = "n" > wdout< / span >
< span class = "n" > mean1< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean1< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > NumCols< / span >
< span class = "n" > mean2< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span >
< span class = "n" > mean2< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean2< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > NumCols< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > NumCols< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > wdout< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > *< / span > < span class = "n" > dout< / span >
< span class = "n" > da< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > wdout< / span > < span class = "o" > -< / span > < span class = "p" > (< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > mean1< / span > < span class = "o" > +< / span > < span class = "n" > mean2< / span > < span class = "p" > ))< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "c1" > # write-back da< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DA< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > da< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # Backward pass (total DW + total DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dwdb< / span > < span class = "p" > (< / span >
< span class = "n" > A< / span > < span class = "p" > ,< / span > < span class = "n" > DOut< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Var< / span > < span class = "p" > ,< / span >
< span class = "n" > DW< / span > < span class = "p" > ,< / span >
< span class = "n" > DB< / span > < span class = "p" > ,< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ):< / span >
< span class = "n" > rows< / span > < span class = "o" > =< / span > < span class = "n" > i< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > offs< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > N< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > rows< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Var< / span > < span class = "o" > +< / span > < span class = "n" > rows< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ])< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > dw< / span > < span class = "o" > +=< / span > < span class = "n" > dout< / span > < span class = "o" > *< / span > < span class = "n" > a_hat< / span >
< span class = "n" > db< / span > < span class = "o" > +=< / span > < span class = "n" > dout< / span >
< span class = "n" > sum_dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > dw< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > sum_db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > db< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_dw< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DB< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_db< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "k" > class< / span > < span class = "nc" > LayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > autograd< / span > < span class = "o" > .< / span > < span class = "n" > Function< / span > < span class = "p" > ):< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > forward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ):< / span >
< span class = "c1" > # allocate output< / span >
< span class = "n" > out< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > )< / span >
< span class = "c1" > # reshape input data into 2D tensor< / span >
< span class = "n" > a_arg< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > a_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > "cuda"< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s2" > "cuda"< / span > < span class = "p" > )< / span >
< span class = "c1" > # Less than 64KB per feature: enqueue fused kernel< / span >
< span class = "n" > MAX_FUSED_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 65536< / span > < span class = "o" > //< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "n" > MAX_FUSED_SIZE< / span > < span class = "p" > ,< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > next_power_of_2< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ))< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > max< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span > < span class = "mi" > 128< / span > < span class = "p" > )< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span > < span class = "mi" > 4096< / span > < span class = "p" > )< / span >
< span class = "c1" > # heuristics for number of warps< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "nb" > max< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > //< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "mi" > 8< / span > < span class = "p" > )< / span >
< span class = "n" > _layer_norm_fwd_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span >
< span class = "n" > out< / span > < span class = "p" > ,< / span >
< span class = "n" > a_arg< / span > < span class = "p" > ,< / span >
< span class = "n" > weight< / span > < span class = "p" > ,< / span >
< span class = "n" > bias< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > a_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > save_for_backward< / span > < span class = "p" > (< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "n" > eps< / span >
< span class = "k" > if< / span > < span class = "nb" > hasattr< / span > < span class = "p" > (< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "s2" > "config"< / span > < span class = "p" > ):< / span >
< span class = "k" > assert< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span > < span class = "o" > ==< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span >
< span class = "n" > grad_scale_name< / span > < span class = "o" > =< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > grad_scale_name< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_gain_bias_name< / span > < span class = "o" > =< / span > < span class = "n" > grad_scale_name< / span >
< span class = "k" > return< / span > < span class = "n" > out< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > backward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > dout< / span > < span class = "p" > ):< / span >
< span class = "k" > assert< / span > < span class = "n" > dout< / span > < span class = "o" > .< / span > < span class = "n" > is_contiguous< / span > < span class = "p" > ()< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > saved_tensors< / span >
< span class = "c1" > # heuristics for the number of parallel reduction streams for DW/DB< / span >
< span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "c1" > # allocate output< / span >
< span class = "n" > da< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > dout< / span > < span class = "p" > )< / span >
< span class = "c1" > # enqueue kernel using forward pass heuristics< / span >
< span class = "c1" > # (DW and DB are computed by a separate kernel, not by this one)< / span >
< span class = "n" > x_arg< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > dweight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dbias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span >
< span class = "n" > da< / span > < span class = "p" > ,< / span >
< span class = "n" > dout< / span > < span class = "p" > ,< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span >
< span class = "n" > weight< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "p" > ,< / span >
< span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "c1" > # accumulate partial sums in separate kernel< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > [< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s2" > "BLOCK_SIZE_N"< / span > < span class = "p" > ])]< / span >
< span class = "n" > _layer_norm_bwd_dwdb< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > dout< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "p" > ,< / span >
< span class = "n" > dweight< / span > < span class = "p" > ,< / span >
< span class = "n" > dbias< / span > < span class = "p" > ,< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span >
< span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "p" > (< / span > < span class = "n" > da< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > dweight< / span > < span class = "p" > ,< / span > < span class = "n" > dbias< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "n" > LayerNorm< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > test_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ):< / span >
< span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "n" > y_tri< / span > < span class = "o" > =< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "n" > y_ref< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass (triton)< / span >
< span class = "n" > y_tri< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_tri< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span >
< span class = "c1" > # backward pass (torch)< / span >
< span class = "n" > y_ref< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_ref< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "c1" > # compare< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > y_tri< / span > < span class = "p" > ,< / span > < span class = "n" > y_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dx_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > db_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' N' < / span > < span class = "p" > ],< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 512< / span > < span class = "o" > *< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 32< / span > < span class = "p" > )],< / span >
< span class = "n" > line_arg< / span > < span class = "o" > =< / span > < span class = "s1" > ' provider' < / span > < span class = "p" > ,< / span >
< span class = "n" > line_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > ' apex' < / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_APEX< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > line_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' Triton' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' Torch' < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > ' Apex' < / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_APEX< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > styles< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s1" > ' blue' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' green' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' orange' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > )],< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s1" > ' GB/s' < / span > < span class = "p" > ,< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "s1" > ' layer-norm' < / span > < span class = "p" > ,< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' M' < / span > < span class = "p" > :< / span > < span class = "mi" > 4096< / span > < span class = "p" > ,< / span > < span class = "s1" > ' dtype' < / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ,< / span > < span class = "s1" > ' mode' < / span > < span class = "p" > :< / span > < span class = "s1" > ' forward' < / span > < span class = "p" > }< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > bench_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ,< / span > < span class = "n" > mode< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ):< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # utility functions< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' apex' < / span > < span class = "p" > :< / span >
< span class = "n" > apex_layer_norm< / span > < span class = "o" > =< / span > < span class = "n" > apex< / span > < span class = "o" > .< / span > < span class = "n" > normalization< / span > < span class = "o" > .< / span > < span class = "n" > FusedLayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > apex_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' forward' < / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ,< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' backward' < / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 3< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ()< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ),< / span >
< span class = "n" > grad_to_none< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ],< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > max_ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > min_ms< / span > < span class = "p" > )< / span >
< span class = "c1" > # test_layer_norm(1151, 8192, torch.float16)< / span >
< span class = "n" > bench_layer_norm< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span > < span class = "n" > save_path< / span > < span class = "o" > =< / span > < span class = "s1" > ' .' < / span > < span class = "p" > ,< / span > < span class = "n" > print_data< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
<!-- 2022-06-18 00:47:57 +00:00 -->
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > (5 minutes 24.904 seconds)< / p >
<!-- 2022-06-05 21:05:02 +00:00 -->
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-05-layer-norm-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 05-layer-norm.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 05-layer-norm.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "../../python-api/triton.html" class = "btn btn-neutral float-right" title = "triton" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "04-low-memory-dropout.html" class = "btn btn-neutral float-left" title = "Low-Memory Dropout" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "05-layer-norm.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >