<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Layer Normalization — Triton documentation< / title >
< link rel = "stylesheet" href = "../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/pygments.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../" src = "../../_static/documentation_options.js" > < / script >
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "Fused Attention" href = "06-fused-attention.html" / >
< link rel = "prev" title = "Low-Memory Dropout" href = "04-low-memory-dropout.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Layer Normalization< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "06-fused-attention.html" > Fused Attention< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "07-libdevice-function.html" > Libdevice function< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-05-layer-norm-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "layer-normalization" >
< span id = "sphx-glr-getting-started-tutorials-05-layer-norm-py" > < / span > < h1 > Layer Normalization< a class = "headerlink" href = "#layer-normalization" title = "Permalink to this headline" > ¶< / a > < / h1 >
< img alt = "05 layer norm" class = "sphx-glr-single-img" src = "../../_images/sphx_glr_05-layer-norm_001.png" / >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > layer-norm:
          N      Triton       Torch        Apex
0    1024.0  585.142849  277.694907  468.114273
1    1536.0  630.153868  323.368435  511.999982
2    2048.0  682.666643  337.814445  520.126988
3    2560.0  694.237267  362.477870  512.000013
4    3072.0  712.347810  378.092307  501.551037
5    3584.0  725.873439  384.859062  451.527536
6    4096.0  728.177767  381.023256  451.972420
7    4608.0  670.254540  396.387087  428.651163
8    5120.0  688.403381  397.669909  420.102563
9    5632.0  704.000002  395.228063  413.357796
10   6144.0  702.171410  402.885254  413.042029
11   6656.0  700.631610  400.360920  400.360920
12   7168.0  690.891575  392.767108  382.293315
13   7680.0  678.895043  393.846167  386.415087
14   8192.0  636.271854  394.795186  377.729113
15   8704.0  624.502255  389.005597  379.465939
16   9216.0  604.327881  406.214877  382.010363
17   9728.0  585.142883  408.524944  383.369452
18  10240.0  564.965524  409.600010  382.803739
19  10752.0  546.133312  411.559798  380.601764
20  11264.0  531.634232  404.997742  371.595879
21  11776.0  520.486200  409.599991  377.587162
22  12288.0  516.031509  413.911572  383.251457
23  12800.0  504.433489  409.599981  377.163903
24  13312.0  494.180982  406.473303  377.645399
25  13824.0  482.934503  412.656711  379.389355
26  14336.0  471.967074  402.414053  370.558967
27  14848.0  461.297068  407.492270  373.534584
28  15360.0  454.269882  406.214870  377.511515
29  15872.0  447.098578  409.599996  377.343238
< / pre > < / div >
< / div >
< div class = "line-block" >
< div class = "line" > < br / > < / div >
< / div >
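< p > For orientation, the forward kernel in the listing that follows computes, for each row, a mean, a biased variance, the reciprocal standard deviation, and then an affine transform with a per-column weight and bias. The block here is a minimal pure-Python sketch of that arithmetic and is not part of the tutorial source. The function name < code > layer_norm_rows< / code > and the default < code > eps=1e-5< / code > are assumptions made for illustration; the kernel itself receives < code > eps< / code > as an argument.< / p >

```python
import math


def layer_norm_rows(rows, weight, bias, eps=1e-5):
    """Reference layer normalization over the last dimension.

    `rows` is a list of equal-length lists of floats; `weight` and `bias`
    have one entry per column. The name and the eps default are
    illustrative assumptions, not taken from the tutorial.
    """
    out, means, rstds = [], [], []
    for row in rows:
        n = len(row)
        # mean over the row
        mean = sum(row) / n
        # biased variance (divide by n, not n - 1)
        var = sum((a - mean) ** 2 for a in row) / n
        # reciprocal standard deviation, stabilized by eps
        rstd = 1.0 / math.sqrt(var + eps)
        # normalize, then multiply by weight and add bias
        out.append([(a - mean) * rstd * w + b
                    for a, w, b in zip(row, weight, bias)])
        means.append(mean)
        rstds.append(rstd)
    return out, means, rstds
```

< p > Comparing a GPU kernel against a slow reference of this kind on a few small rows is a common way to check numerical agreement before benchmarking.< / p >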
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "k" > try< / span > < span class = "p" > :< / span >
< span class = "c1" > # This is https://github.com/NVIDIA/apex, NOT the apex on PyPI, so it< / span >
< span class = "c1" > # should not be added to extras_require in setup.py.< / span >
< span class = "kn" > import< / span > < span class = "nn" > apex< / span >
< span class = "n" > HAS_APEX< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "k" > except< / span > < span class = "ne" > ModuleNotFoundError< / span > < span class = "p" > :< / span >
< span class = "n" > HAS_APEX< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_fwd_fused< / span > < span class = "p" > (< / span >
< span class = "n" > Out< / span > < span class = "p" > ,< / span >
< span class = "n" > A< / span > < span class = "p" > ,< / span >
< span class = "n" > Weight< / span > < span class = "p" > ,< / span >
< span class = "n" > Bias< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # position of elements processed by this program< / span >
< span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > Out< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > A< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "c1" > # compute mean< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span >
< span class = "n" > _mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > "evict_last"< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > _mean< / span > < span class = "o" > +=< / span > < span class = "n" > a< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
< span class = "c1" > # compute variance< / span >
< span class = "n" > _var< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > "evict_last"< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > _var< / span > < span class = "o" > +=< / span > < span class = "n" > a< / span > < span class = "o" > *< / span > < span class = "n" > a< / span >
< span class = "n" > var< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _var< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "o" > /< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sqrt< / span > < span class = "p" > (< / span > < span class = "n" > var< / span > < span class = "o" > +< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "c1" > # write-back mean/rstd< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Rstd< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > )< / span >
< span class = "c1" > # multiply by weight and add bias< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Bias< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > ,< / span > < span class = "n" > eviction_policy< / span > < span class = "o" > =< / span > < span class = "s2" > "evict_first"< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > out< / span > < span class = "o" > =< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > weight< / span > < span class = "o" > +< / span > < span class = "n" > bias< / span >
< span class = "c1" > # write-back< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Out< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > out< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # Backward pass (DA + partial DW + partial DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > (< / span >
< span class = "n" > _DA< / span > < span class = "p" > ,< / span >
< span class = "n" > _DOut< / span > < span class = "p" > ,< / span >
< span class = "n" > _A< / span > < span class = "p" > ,< / span >
< span class = "n" > Weight< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > NumRows< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # position of elements processed by this program< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span >
< span class = "n" > A< / span > < span class = "o" > =< / span > < span class = "n" > _A< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > DOut< / span > < span class = "o" > =< / span > < span class = "n" > _DOut< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > DA< / span > < span class = "o" > =< / span > < span class = "n" > _DA< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Rstd< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
< span class = "c1" > # load data to SRAM< / span >
< span class = "n" > _mean1< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > _mean2< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ([< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ],< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > NumCols< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > wdout< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > *< / span > < span class = "n" > dout< / span >
< span class = "n" > _mean1< / span > < span class = "o" > +=< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > wdout< / span >
< span class = "n" > _mean2< / span > < span class = "o" > +=< / span > < span class = "n" > wdout< / span >
< span class = "n" > mean1< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean1< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > NumCols< / span >
< span class = "n" > mean2< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > _mean2< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > NumCols< / span >
< span class = "k" > for< / span > < span class = "n" > off< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > NumCols< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ):< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > off< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > NumCols< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Weight< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "n" > wdout< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > *< / span > < span class = "n" > dout< / span >
< span class = "n" > da< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > wdout< / span > < span class = "o" > -< / span > < span class = "p" > (< / span > < span class = "n" > a_hat< / span > < span class = "o" > *< / span > < span class = "n" > mean1< / span > < span class = "o" > +< / span > < span class = "n" > mean2< / span > < span class = "p" > ))< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
< span class = "c1" > # write-back da< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DA< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > da< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # Backward pass (total DW + total DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dwdb< / span > < span class = "p" > (< / span >
< span class = "n" > A< / span > < span class = "p" > ,< / span > < span class = "n" > DOut< / span > < span class = "p" > ,< / span >
< span class = "n" > Mean< / span > < span class = "p" > ,< / span > < span class = "n" > Var< / span > < span class = "p" > ,< / span >
< span class = "n" > DW< / span > < span class = "p" > ,< / span >
< span class = "n" > DB< / span > < span class = "p" > ,< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > UNROLL< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span >
< span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > *< / span > < span class = "n" > UNROLL< / span > < span class = "p" > ):< / span >
< span class = "k" > for< / span > < span class = "n" > j< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > UNROLL< / span > < span class = "p" > ):< / span >
< span class = "n" > rows< / span > < span class = "o" > =< / span > < span class = "n" > i< / span > < span class = "o" > +< / span > < span class = "n" > j< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > offs< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > N< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > a< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > A< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > dout< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DOut< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Mean< / span > < span class = "o" > +< / span > < span class = "n" > rows< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Var< / span > < span class = "o" > +< / span > < span class = "n" > rows< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > a_hat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ])< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "n" > dw< / span > < span class = "o" > +=< / span > < span class = "n" > dout< / span > < span class = "o" > *< / span > < span class = "n" > a_hat< / span >
< span class = "n" > db< / span > < span class = "o" > +=< / span > < span class = "n" > dout< / span >
< span class = "n" > sum_dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > dw< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > sum_db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > db< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_dw< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DB< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_db< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "k" > class< / span > < span class = "nc" > LayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > autograd< / span > < span class = "o" > .< / span > < span class = "n" > Function< / span > < span class = "p" > ):< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > forward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ):< / span >
< span class = "c1" > # allocate output< / span >
< span class = "n" > out< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > )< / span >
< span class = "c1" > # reshape input data into 2D tensor< / span >
< span class = "n" > a_arg< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > a_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "c1" > # Less than 64KB per feature: enqueue fused kernel< / span >
< span class = "n" > MAX_FUSED_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 65536< / span > < span class = "o" > //< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "n" > MAX_FUSED_SIZE< / span > < span class = "p" > ,< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > next_power_of_2< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ))< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > max< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span > < span class = "mi" > 128< / span > < span class = "p" > )< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span > < span class = "mi" > 4096< / span > < span class = "p" > )< / span >
< span class = "c1" > # heuristics for number of warps< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "nb" > max< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > //< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "mi" > 8< / span > < span class = "p" > )< / span >
< span class = "n" > _layer_norm_fwd_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span >
< span class = "n" > out< / span > < span class = "p" > ,< / span >
< span class = "n" > a_arg< / span > < span class = "p" > ,< / span >
< span class = "n" > weight< / span > < span class = "p" > ,< / span >
< span class = "n" > bias< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > a_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > save_for_backward< / span > < span class = "p" > (< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "n" > eps< / span >
< span class = "k" > if< / span > < span class = "nb" > hasattr< / span > < span class = "p" > (< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "s2" > "config"< / span > < span class = "p" > ):< / span >
< span class = "k" > assert< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span > < span class = "o" > ==< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span >
< span class = "n" > grad_scale_name< / span > < span class = "o" > =< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > config< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_name< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > grad_scale_name< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > grad_scale_gain_bias_name< / span > < span class = "o" > =< / span > < span class = "n" > grad_scale_name< / span >
< span class = "k" > return< / span > < span class = "n" > out< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > backward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > dout< / span > < span class = "p" > ):< / span >
< span class = "k" > assert< / span > < span class = "n" > dout< / span > < span class = "o" > .< / span > < span class = "n" > is_contiguous< / span > < span class = "p" > ()< / span >
< span class = "c1" > # note: `var` holds the reciprocal standard deviation (rstd) saved by forward< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > saved_tensors< / span >
< span class = "c1" > # heuristics for the number of parallel reduction streams for DW/DB< / span >
< span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "c1" > # allocate output< / span >
< span class = "n" > da< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > dout< / span > < span class = "p" > )< / span >
< span class = "c1" > # enqueue the input-gradient kernel using forward pass heuristics< / span >
< span class = "c1" > # (DW and DB are reduced by the separate _layer_norm_bwd_dwdb kernel)< / span >
< span class = "n" > x_arg< / span > < span class = "o" > =< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > a< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > dweight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dbias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span >
< span class = "n" > da< / span > < span class = "p" > ,< / span >
< span class = "n" > dout< / span > < span class = "p" > ,< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span >
< span class = "n" > weight< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "p" > ,< / span >
< span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > N< / span > < span class = "o" > &gt;< / span > < span class = "mi" > 10240< / span > < span class = "p" > :< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "c1" > # maximize occupancy for small N< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "mi" > 16< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 16< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 8< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > [< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s2" > "BLOCK_SIZE_N"< / span > < span class = "p" > ])]< / span >
< span class = "n" > _layer_norm_bwd_dwdb< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span >
< span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > dout< / span > < span class = "p" > ,< / span >
< span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > var< / span > < span class = "p" > ,< / span >
< span class = "n" > dweight< / span > < span class = "p" > ,< / span >
< span class = "n" > dbias< / span > < span class = "p" > ,< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span >
< span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "p" > (< / span > < span class = "n" > da< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > dweight< / span > < span class = "p" > ,< / span > < span class = "n" > dbias< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "n" > LayerNorm< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span > < span class = "p" > (< / span > < span class = "n" > a< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > test_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ):< / span >
< span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "n" > y_tri< / span > < span class = "o" > =< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "n" > y_ref< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass (triton)< / span >
< span class = "n" > y_tri< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_tri< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span >
< span class = "c1" > # backward pass (torch)< / span >
< span class = "n" > y_ref< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_ref< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "c1" > # compare< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > y_tri< / span > < span class = "p" > ,< / span > < span class = "n" > y_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dx_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > db_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > 'N'< / span > < span class = "p" > ],< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 512< / span > < span class = "o" > *< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 32< / span > < span class = "p" > )],< / span >
< span class = "n" > line_arg< / span > < span class = "o" > =< / span > < span class = "s1" > 'provider'< / span > < span class = "p" > ,< / span >
< span class = "n" > line_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > 'triton'< / span > < span class = "p" > ,< / span > < span class = "s1" > 'torch'< / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > 'apex'< / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_APEX< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > line_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > 'Triton'< / span > < span class = "p" > ,< / span > < span class = "s1" > 'Torch'< / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "p" > ([< / span > < span class = "s1" > 'Apex'< / span > < span class = "p" > ]< / span > < span class = "k" > if< / span > < span class = "n" > HAS_APEX< / span > < span class = "k" > else< / span > < span class = "p" > []),< / span >
< span class = "n" > styles< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s1" > 'blue'< / span > < span class = "p" > ,< / span > < span class = "s1" > '-'< / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > 'green'< / span > < span class = "p" > ,< / span > < span class = "s1" > '-'< / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > 'orange'< / span > < span class = "p" > ,< / span > < span class = "s1" > '-'< / span > < span class = "p" > )],< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s1" > 'GB/s'< / span > < span class = "p" > ,< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "s1" > 'layer-norm'< / span > < span class = "p" > ,< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > 'M'< / span > < span class = "p" > :< / span > < span class = "mi" > 4096< / span > < span class = "p" > ,< / span > < span class = "s1" > 'dtype'< / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ,< / span > < span class = "s1" > 'mode'< / span > < span class = "p" > :< / span > < span class = "s1" > 'forward'< / span > < span class = "p" > }< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > bench_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ,< / span > < span class = "n" > mode< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ):< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > 'cuda'< / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # utility functions< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > 'triton'< / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > 'torch'< / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > 'apex'< / span > < span class = "p" > :< / span >
< span class = "n" > apex_layer_norm< / span > < span class = "o" > =< / span > < span class = "n" > apex< / span > < span class = "o" > .< / span > < span class = "n" > normalization< / span > < span class = "o" > .< / span > < span class = "n" > FusedLayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > apex_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > 'forward'< / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ,< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > 'backward'< / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 3< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ()< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ),< / span >
< span class = "n" > grad_to_none< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ],< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > max_ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > min_ms< / span > < span class = "p" > )< / span >
< span class = "c1" > # test_layer_norm(1151, 8192, torch.float16)< / span >
< span class = "n" > bench_layer_norm< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span > < span class = "n" > save_path< / span > < span class = "o" > =< / span > < span class = "s1" > '.'< / span > < span class = "p" > ,< / span > < span class = "n" > print_data< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
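< p > The launch configuration chosen for the weight and bias gradient kernel in the backward pass above can be reproduced without a GPU. The following is a minimal sketch, assuming that < code > triton.cdiv< / code > performs plain ceiling division; the helper names < code > cdiv< / code > and < code > dwdb_launch_config< / code > are hypothetical and are not part of Triton or of this tutorial. The tile sizes and warp counts are copied from the listing.< / p >

```python
def cdiv(x, y):
    """Ceiling division of two positive integers (assumed to match triton.cdiv)."""
    return (x + y - 1) // y


def dwdb_launch_config(n_cols):
    """Hypothetical helper mirroring the tile-size heuristic used in the backward pass."""
    if n_cols > 10240:
        block_n, block_m, num_warps = 128, 32, 4
    else:
        # Small N: narrow tiles create more programs, which helps occupancy.
        block_n, block_m, num_warps = 16, 16, 8
    # One program per BLOCK_SIZE_N-wide slice of columns.
    grid = (cdiv(n_cols, block_n),)
    return {
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_M": block_m,
        "num_warps": num_warps,
        "grid": grid,
    }


if __name__ == "__main__":
    for n in (1024, 8192, 16384):
        print(n, dwdb_launch_config(n))
```

< p > With these values, every N up to 10240 is split into 16-column slices, so N = 8192 launches 512 programs, while N = 16384 switches to 128-column slices and launches 128 programs.< / p >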
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 5 minutes 37.218 seconds)< / p >
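< p > The < code > gbps< / code > lambdas in the benchmark convert a time in milliseconds into effective bandwidth in GB/s. The sketch below restates that arithmetic as a standalone function; the name < code > effective_gbps< / code > and its < code > passes< / code > parameter are hypothetical and introduced only for illustration. The listing counts 2 tensor-sized transfers for the forward pass and 3 for the backward pass.< / p >

```python
def effective_gbps(num_elements, element_size, ms, passes=2):
    """Effective bandwidth in GB/s for `passes` tensor-sized memory transfers.

    bytes / ms equals 1e3 bytes/s, and 1 GB/s is 1e9 bytes/s,
    hence the 1e-6 factor used in the benchmark.
    """
    total_bytes = passes * num_elements * element_size
    return total_bytes / ms * 1e-6


if __name__ == "__main__":
    # Example shape from the benchmark arguments: M = 4096 rows of float16 (2 bytes).
    m, n = 4096, 8192
    # The 0.5 ms timing is an illustrative input, not a measured result.
    print(effective_gbps(m * n, 2, ms=0.5))
```

< p > Because bandwidth is inversely proportional to time, the slowest run gives the lowest GB/s figure. That is why the benchmark returns < code > gbps(max_ms)< / code > before < code > gbps(min_ms)< / code > : the triple is ordered as (typical, lower bound, upper bound) in bandwidth terms.< / p >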
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-05-layer-norm-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 05-layer-norm.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 05-layer-norm.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "06-fused-attention.html" class = "btn btn-neutral float-right" title = "Fused Attention" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "04-low-memory-dropout.html" class = "btn btn-neutral float-left" title = "Low-Memory Dropout" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "05-layer-norm.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >