<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Layer Normalization — Triton documentation< / title >
< link rel = "stylesheet" href = "../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/pygments.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../" src = "../../_static/documentation_options.js" > < / script >
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "triton" href = "../../python-api/triton.html" / >
< link rel = "prev" title = "Low-Memory Dropout" href = "04-low-memory-dropout.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Layer Normalization< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-05-layer-norm-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "layer-normalization" >
< span id = "sphx-glr-getting-started-tutorials-05-layer-norm-py" > < / span > < h1 > Layer Normalization< a class = "headerlink" href = "#layer-normalization" title = "Permalink to this headline" > ¶< / a > < / h1 >
< img alt = "layer-norm-backward benchmark plot: Triton, Torch and Apex" class = "sphx-glr-single-img" src = "../../_images/sphx_glr_05-layer-norm_001.png" / >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > layer-norm-backward:
          N      Triton       Torch        Apex
0    1024.0  311.088617   99.902435  311.088617
1    1536.0  354.461542  133.083026  341.333333
2    2048.0  427.408686  158.554837  321.254900
3    2560.0  461.954908  182.857144  323.368411
4    3072.0  515.580429  191.999993  319.168834
5    3584.0  551.384634  208.271186  309.410081
6    4096.0  568.231237  219.919464  299.707322
7    4608.0  500.416301  232.825259  287.251954
8    5120.0  529.655159  243.809526  289.811322
9    5632.0  540.671974  244.869560  291.310338
10   6144.0  548.163546  251.631408  288.000001
11   6656.0  534.260858  256.000009  286.279570
12   7168.0  516.612607  254.485198  278.820105
13   7680.0  487.619051  266.743841  284.884090
14   8192.0  468.114289  257.003920  276.912679
15   8704.0  416.958106  267.815384  285.767450
16   9216.0  430.319054  274.081793  289.887291
17   9728.0  439.683593  280.278512  289.308559
18  10240.0  446.025405  287.102804  290.153487
19  10752.0  430.797982  246.699797  289.291486
20  11264.0  429.104745  246.656943  286.980888
21  11776.0  422.457417  250.109737  288.981596
22  12288.0  419.504980  254.893699  294.323369
23  12800.0  414.574901  254.094291  288.993430
24  13312.0  413.309181  252.759501  289.653667
25  13824.0  407.587209  257.390218  292.056329
26  14336.0  395.930964  255.429842  288.644296
27  14848.0  386.918555  257.293872  287.380642
28  15360.0  375.015246  258.513318  286.656296
29  15872.0  368.046389  261.267482  289.679087
< / pre > < / div >
< / div >
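<p>For reference, the fused kernels below implement the standard layer-norm equations row by row. The forward pass uses the biased variance, <code>rstd = 1/sqrt(var + eps)</code>, <code>xhat = (x - mean) * rstd</code> and <code>y = xhat * w + b</code>; the backward pass computes <code>dx = (w*dy - (xhat*c1 + c2)) * rstd</code> with <code>c1 = mean(xhat * w * dy)</code> and <code>c2 = mean(w * dy)</code>. The following pure-Python sketch (no Triton or GPU needed) mirrors those formulas for a single row and checks the analytic <code>dx</code> against central finite differences of <code>L(x) = sum(dy * y(x))</code>. The helpers <code>layer_norm_row_fwd</code>, <code>layer_norm_row_bwd</code> and <code>numeric_dx</code> are hypothetical illustrations, not part of Triton.</p>

```python
import math
from typing import List, Tuple


def layer_norm_row_fwd(x: List[float], w: List[float], b: List[float],
                       eps: float = 1e-5) -> Tuple[List[float], float, float]:
    """One row of the forward pass, following _layer_norm_fwd_fused.

    Returns y together with the mean and rstd that the kernel stores in M and V.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n  # biased variance, as in the kernel
    rstd = 1.0 / math.sqrt(var + eps)
    y = [(xi - mean) * rstd * wi + bi for xi, wi, bi in zip(x, w, b)]
    return y, mean, rstd


def layer_norm_row_bwd(dy: List[float], x: List[float], w: List[float],
                       mean: float, rstd: float) -> List[float]:
    """dx for one row, following the formula in _layer_norm_bwd_dx_fused."""
    n = len(x)
    xhat = [(xi - mean) * rstd for xi in x]
    wdy = [wi * dyi for wi, dyi in zip(w, dy)]
    mean1 = sum(h * g for h, g in zip(xhat, wdy)) / n  # c1 = mean(xhat * w * dy)
    mean2 = sum(wdy) / n                               # c2 = mean(w * dy)
    return [(g - (h * mean1 + mean2)) * rstd for h, g in zip(xhat, wdy)]


def numeric_dx(dy: List[float], x: List[float], w: List[float], b: List[float],
               eps: float = 1e-5, step: float = 1e-6) -> List[float]:
    """Central finite differences of L(x) = sum(dy * y(x)), used as a check."""
    def loss(xs: List[float]) -> float:
        y, _, _ = layer_norm_row_fwd(xs, w, b, eps)
        return sum(d * yi for d, yi in zip(dy, y))

    grads = []
    for i in range(len(x)):
        x_plus = list(x)
        x_plus[i] += step
        x_minus = list(x)
        x_minus[i] -= step
        grads.append((loss(x_plus) - loss(x_minus)) / (2 * step))
    return grads


if __name__ == "__main__":
    x = [0.5, -1.2, 3.3, 0.0, 2.1]
    w = [1.0, 0.5, -2.0, 1.5, 0.8]
    b = [0.1, 0.0, -0.3, 0.2, 0.0]
    dy = [0.3, -0.7, 1.1, 0.4, -0.2]
    y, mean, rstd = layer_norm_row_fwd(x, w, b)
    dx = layer_norm_row_bwd(dy, x, w, mean, rstd)
    ref = numeric_dx(dy, x, w, b)
    print(max(abs(a - r) for a, r in zip(dx, ref)))  # should be close to 0
```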
< div class = "line-block" >
< div class = "line" > < br / > < / div >
< / div >
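<p>In <code>_layer_norm_bwd_dx_fused</code>, the contributions of each row to the weight and bias gradients (<code>dy * xhat</code> and <code>dy</code>) are added into buffer <code>row % GROUP_SIZE_M</code> while holding that buffer's spin lock (<code>tl.atomic_cas</code> to acquire, <code>tl.atomic_xchg</code> to release), and a per-buffer count flag lets the first writer overwrite the buffer instead of reading it. <code>_layer_norm_bwd_dwdb</code> then combines the partial buffers into the totals (<code>FINAL_DW</code>, <code>FINAL_DB</code>). The sketch below models this two-stage scheme on the CPU, assuming Python threads and <code>threading.Lock</code> as stand-ins for GPU programs and the atomic spin lock. It only illustrates the bookkeeping, says nothing about performance, and the name <code>grouped_dw_db</code> is hypothetical.</p>

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


def grouped_dw_db(dy: List[List[float]], xhat: List[List[float]],
                  group_size_m: int) -> Tuple[List[float], List[float]]:
    """Two-stage dW/dB reduction modelled on the tutorial's backward kernels.

    Stage 1: each simulated program (one row) adds dy * xhat and dy into
    buffer row % group_size_m while holding that buffer's lock.
    Stage 2: sum the group_size_m buffers column by column.
    """
    m, n = len(dy), len(dy[0])
    # Assumption of this sketch: every buffer receives at least one row.
    assert m >= group_size_m, "sketch assumes at least one row per buffer"
    # Deliberately poisoned with NaN: thanks to the count flag, the first
    # writer of each buffer overwrites it and never reads these values.
    dw_buf = [[float("nan")] * n for _ in range(group_size_m)]
    db_buf = [[float("nan")] * n for _ in range(group_size_m)]
    count = [0] * group_size_m
    locks = [threading.Lock() for _ in range(group_size_m)]

    def program(row: int) -> None:
        g = row % group_size_m
        partial_dw = [d * h for d, h in zip(dy[row], xhat[row])]
        partial_db = list(dy[row])
        with locks[g]:  # stand-in for the atomic_cas / atomic_xchg spin lock
            if count[g] == 0:
                count[g] = 1  # first store does not accumulate
            else:
                partial_dw = [p + q for p, q in zip(partial_dw, dw_buf[g])]
                partial_db = [p + q for p, q in zip(partial_db, db_buf[g])]
            dw_buf[g] = partial_dw
            db_buf[g] = partial_db

    with ThreadPoolExecutor(max_workers=8) as pool:
        list(pool.map(program, range(m)))  # list() re-raises worker errors

    # Stage 2, analogous to _layer_norm_bwd_dwdb: reduce over the buffers.
    dw = [sum(dw_buf[g][j] for g in range(group_size_m)) for j in range(n)]
    db = [sum(db_buf[g][j] for g in range(group_size_m)) for j in range(n)]
    return dw, db


if __name__ == "__main__":
    rows, cols, groups = 37, 6, 4
    dy = [[((i * 7 + j * 3) % 11 - 5) / 5 for j in range(cols)] for i in range(rows)]
    xhat = [[((i * 5 + j * 2) % 13 - 6) / 4 for j in range(cols)] for i in range(rows)]
    dw, db = grouped_dw_db(dy, xhat, groups)
    ref_dw = [sum(dy[i][j] * xhat[i][j] for i in range(rows)) for j in range(cols)]
    print(max(abs(a - r) for a, r in zip(dw, ref_dw)))  # should be close to 0
```

<p>Because the buffers start out filled with NaN in this sketch, any read of a buffer before its first write would propagate NaN into the result; the count flag is what prevents that. Spreading rows over several buffers also means that rows mapped to different buffers never wait on the same lock.</p>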
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "c1" > # Forward Pass< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_fwd_fused< / span > < span class = "p" > (< / span > < span class = "n" > X< / span > < span class = "p" > ,< / span > < span class = "n" > Y< / span > < span class = "p" > ,< / span > < span class = "n" > W< / span > < span class = "p" > ,< / span > < span class = "n" > B< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > V< / span > < span class = "p" > ,< / span > < span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span > < span class = "o" > **< / span > < span class = "n" > META< / span > < span class = "p" > ):< / span >
    < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > META< / span > < span class = "p" > [< / span > < span class = "s1" > 'BLOCK_SIZE'< / span > < span class = "p" > ]< / span >
    < span class = "c1" > # position of elements processed by this program< / span >
    < span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
    < span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
    < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > &lt;< / span > < span class = "n" > N< / span >
    < span class = "c1" > # offset data pointers to start at the row of interest< / span >
    < span class = "n" > X< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
    < span class = "n" > Y< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
    < span class = "c1" > # load data and cast to float32< / span >
    < span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > X< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
    < span class = "c1" > # compute mean< / span >
    < span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
    < span class = "c1" > # compute variance and reciprocal standard deviation< / span >
    < span class = "n" > xmean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
    < span class = "n" > var< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > xmean< / span > < span class = "o" > *< / span > < span class = "n" > xmean< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
    < span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "o" > /< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sqrt< / span > < span class = "p" > (< / span > < span class = "n" > var< / span > < span class = "o" > +< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
    < span class = "n" > xhat< / span > < span class = "o" > =< / span > < span class = "n" > xmean< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
    < span class = "c1" > # write-back mean/rstd< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > V< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > )< / span >
    < span class = "c1" > # multiply by weight and add bias< / span >
    < span class = "n" > w< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > W< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "n" > b< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > B< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > xhat< / span > < span class = "o" > *< / span > < span class = "n" > w< / span > < span class = "o" > +< / span > < span class = "n" > b< / span >
    < span class = "c1" > # write-back< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > Y< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >


< span class = "c1" > # Backward pass (DX + partial DW + partial DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > (< / span > < span class = "n" > DX< / span > < span class = "p" > ,< / span > < span class = "n" > DY< / span > < span class = "p" > ,< / span > < span class = "n" > DW< / span > < span class = "p" > ,< / span > < span class = "n" > DB< / span > < span class = "p" > ,< / span > < span class = "n" > X< / span > < span class = "p" > ,< / span > < span class = "n" > W< / span > < span class = "p" > ,< / span > < span class = "n" > B< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > V< / span > < span class = "p" > ,< / span > < span class = "n" > Lock< / span > < span class = "p" > ,< / span >
                             < span class = "n" > stride< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
                             < span class = "o" > **< / span > < span class = "n" > META< / span > < span class = "p" > ):< / span >
    < span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "n" > META< / span > < span class = "p" > [< / span > < span class = "s1" > 'GROUP_SIZE_M'< / span > < span class = "p" > ]< / span >
    < span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > META< / span > < span class = "p" > [< / span > < span class = "s1" > 'BLOCK_SIZE_N'< / span > < span class = "p" > ]< / span >
    < span class = "c1" > # position of elements processed by this program< / span >
    < span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
    < span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
    < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > &lt;< / span > < span class = "n" > N< / span >
    < span class = "c1" > # offset data pointers to start at the row of interest< / span >
    < span class = "n" > X< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
    < span class = "n" > DY< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
    < span class = "n" > DX< / span > < span class = "o" > +=< / span > < span class = "n" > row< / span > < span class = "o" > *< / span > < span class = "n" > stride< / span >
    < span class = "c1" > # offset locks and weight/bias gradient pointers< / span >
    < span class = "c1" > # each kernel instance accumulates partial sums for< / span >
    < span class = "c1" > # DW and DB into one of GROUP_SIZE_M independent buffers< / span >
    < span class = "c1" > # these buffers stay in the L2, which allows this kernel< / span >
    < span class = "c1" > # to be fast< / span >
    < span class = "n" > lock_id< / span > < span class = "o" > =< / span > < span class = "n" > row< / span > < span class = "o" > %< / span > < span class = "n" > GROUP_SIZE_M< / span >
    < span class = "n" > Lock< / span > < span class = "o" > +=< / span > < span class = "n" > lock_id< / span >
    < span class = "n" > Count< / span > < span class = "o" > =< / span > < span class = "n" > Lock< / span > < span class = "o" > +< / span > < span class = "n" > GROUP_SIZE_M< / span >
    < span class = "n" > DW< / span > < span class = "o" > =< / span > < span class = "n" > DW< / span > < span class = "o" > +< / span > < span class = "n" > lock_id< / span > < span class = "o" > *< / span > < span class = "n" > N< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span >
    < span class = "n" > DB< / span > < span class = "o" > =< / span > < span class = "n" > DB< / span > < span class = "o" > +< / span > < span class = "n" > lock_id< / span > < span class = "o" > *< / span > < span class = "n" > N< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span >
    < span class = "c1" > # load data to SRAM< / span >
    < span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > X< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
    < span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DY< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
    < span class = "n" > w< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > W< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
    < span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
    < span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > V< / span > < span class = "o" > +< / span > < span class = "n" > row< / span > < span class = "p" > )< / span >
    < span class = "c1" > # compute dx< / span >
    < span class = "n" > xhat< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > -< / span > < span class = "n" > mean< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
    < span class = "n" > wdy< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > *< / span > < span class = "n" > dy< / span >
    < span class = "n" > xhat< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > xhat< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
    < span class = "n" > wdy< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > wdy< / span > < span class = "p" > ,< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
    < span class = "n" > mean1< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > xhat< / span > < span class = "o" > *< / span > < span class = "n" > wdy< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
    < span class = "n" > mean2< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > wdy< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > N< / span >
    < span class = "n" > dx< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > wdy< / span > < span class = "o" > -< / span > < span class = "p" > (< / span > < span class = "n" > xhat< / span > < span class = "o" > *< / span > < span class = "n" > mean1< / span > < span class = "o" > +< / span > < span class = "n" > mean2< / span > < span class = "p" > ))< / span > < span class = "o" > *< / span > < span class = "n" > rstd< / span >
    < span class = "c1" > # write-back dx< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DX< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > dx< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "c1" > # accumulate partial sums for dw/db< / span >
    < span class = "n" > partial_dw< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "o" > *< / span > < span class = "n" > xhat< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
    < span class = "n" > partial_db< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
    < span class = "k" > while< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > atomic_cas< / span > < span class = "p" > (< / span > < span class = "n" > Lock< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > ==< / span > < span class = "mi" > 1< / span > < span class = "p" > :< / span >
        < span class = "k" > pass< / span >
    < span class = "n" > count< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > Count< / span > < span class = "p" > )< / span >
    < span class = "c1" > # first store doesn't accumulate< / span >
    < span class = "k" > if< / span > < span class = "n" > count< / span > < span class = "o" > ==< / span > < span class = "mi" > 0< / span > < span class = "p" > :< / span >
        < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > atomic_xchg< / span > < span class = "p" > (< / span > < span class = "n" > Count< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
    < span class = "k" > else< / span > < span class = "p" > :< / span >
        < span class = "n" > partial_dw< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
        < span class = "n" > partial_db< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DB< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "p" > ,< / span > < span class = "n" > partial_dw< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > DB< / span > < span class = "p" > ,< / span > < span class = "n" > partial_db< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
    < span class = "c1" > # release lock< / span >
    < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > atomic_xchg< / span > < span class = "p" > (< / span > < span class = "n" > Lock< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >


< span class = "c1" > # Backward pass (total DW + total DB)< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _layer_norm_bwd_dwdb< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "p" > ,< / span > < span class = "n" > DB< / span > < span class = "p" > ,< / span > < span class = "n" > FINAL_DW< / span > < span class = "p" > ,< / span > < span class = "n" > FINAL_DB< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "o" > **< / span > < span class = "n" > meta< / span > < span class = "p" > ):< / span >
    < span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
    < span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > 'BLOCK_SIZE_M'< / span > < span class = "p" > ]< / span >
    < span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > 'BLOCK_SIZE_N'< / span > < span class = "p" > ]< / span >
    < span class = "n" > cols< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > )< / span >
< span class = "n" > dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > ((< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_N< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > ):< / span >
< span class = "n" > rows< / span > < span class = "o" > =< / span > < span class = "n" > i< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE_M< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > < < / span > < span class = "n" > M< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > offs< / span > < span class = "o" > =< / span > < span class = "n" > rows< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span > < span class = "o" > *< / span > < span class = "n" > N< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > [< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "p" > :]< / span >
< span class = "n" > dw< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DW< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > db< / span > < span class = "o" > +=< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > DB< / span > < span class = "o" > +< / span > < span class = "n" > offs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "p" > )< / span >
< span class = "n" > sum_dw< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > dw< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > sum_db< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > db< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > FINAL_DW< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_dw< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > FINAL_DB< / span > < span class = "o" > +< / span > < span class = "n" > cols< / span > < span class = "p" > ,< / span > < span class = "n" > sum_db< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > cols< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
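
# The two backward kernels split the DW/DB reduction into two stages. The pure-Python
# sketch below is a hypothetical helper that the kernels do not use. It assumes, as in
# the first backward kernel, that row `row` accumulates into partial buffer
# `row % GROUP_SIZE_M`, with partial DW += dy * xhat and partial DB += dy, where xhat
# is the normalized input. _layer_norm_bwd_dwdb then sums the GROUP_SIZE_M partial
# rows column by column. The result is mathematically equal to a direct column sum,
# although floating-point rounding can differ because the summation order differs.
def _reference_two_stage_dwdb(dy, xhat, group_size_m):
    """dy, xhat: lists of M rows, each a list of N floats. Returns (dw, db)."""
    n = len(dy[0])
    partial_dw = [[0.0] * n for _ in range(group_size_m)]
    partial_db = [[0.0] * n for _ in range(group_size_m)]
    # stage 1: each row adds its contribution into one of GROUP_SIZE_M buffers
    for row, (dy_row, xhat_row) in enumerate(zip(dy, xhat)):
        buf = row % group_size_m
        for col in range(n):
            partial_dw[buf][col] += dy_row[col] * xhat_row[col]
            partial_db[buf][col] += dy_row[col]
    # stage 2: reduce the partial buffers over axis 0, one column at a time
    dw = [sum(partial_dw[g][col] for g in range(group_size_m)) for col in range(n)]
    db = [sum(partial_db[g][col] for g in range(group_size_m)) for col in range(n)]
    return dw, db
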
< span class = "k" > class< / span > < span class = "nc" > LayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > autograd< / span > < span class = "o" > .< / span > < span class = "n" > Function< / span > < span class = "p" > ):< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > forward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > normalized_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ):< / span >
< span class = "c1" > # allocate output< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # reshape input data into 2D tensor< / span >
< span class = "n" > x_arg< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > mean< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > rstd< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "c1" > # Rows of at most 64KB fit in a single block: enqueue the fused kernel< / span >
< span class = "n" > MAX_FUSED_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 65536< / span > < span class = "o" > //< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "n" > MAX_FUSED_SIZE< / span > < span class = "p" > ,< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > next_power_of_2< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "n" > N< / span > < span class = "o" > > < / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span >
< span class = "k" > raise< / span > < span class = "ne" > RuntimeError< / span > < span class = "p" > (< / span > < span class = "s2" > " This layer norm does not support feature dims larger than 64KB." < / span > < span class = "p" > )< / span >
< span class = "c1" > # heuristics for number of warps< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "nb" > min< / span > < span class = "p" > (< / span > < span class = "nb" > max< / span > < span class = "p" > (< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > //< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "mi" > 8< / span > < span class = "p" > )< / span >
< span class = "c1" > # enqueue kernel< / span >
< span class = "n" > _layer_norm_fwd_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span > < span class = "n" > x_arg< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > ,< / span >
< span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > save_for_backward< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > mean< / span > < span class = "p" > ,< / span > < span class = "n" > rstd< / span > < span class = "p" > )< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span >
< span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "n" > eps< / span >
< span class = "k" > return< / span > < span class = "n" > y< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > backward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > dy< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w< / span > < span class = "p" > ,< / span > < span class = "n" > b< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > saved_tensors< / span >
< span class = "c1" > # heuristics for the number of parallel reduction streams for DW/DB< / span >
< span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 64< / span >
< span class = "k" > if< / span > < span class = "n" > N< / span > < span class = "o" > < =< / span > < span class = "mi" > 8192< / span > < span class = "p" > :< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 96< / span >
< span class = "k" > if< / span > < span class = "n" > N< / span > < span class = "o" > < =< / span > < span class = "mi" > 4096< / span > < span class = "p" > :< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span >
< span class = "k" > if< / span > < span class = "n" > N< / span > < span class = "o" > < =< / span > < span class = "mi" > 1024< / span > < span class = "p" > :< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 256< / span >
< span class = "c1" > # allocate locks, partial-sum buffers and outputs< / span >
< span class = "n" > locks< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > _dw< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > _db< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dw< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > db< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty< / span > < span class = "p" > ((< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ],),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > w< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dx< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > )< / span >
< span class = "c1" > # enqueue kernel using forward pass heuristics< / span >
< span class = "c1" > # also compute partial sums for DW and DB< / span >
< span class = "n" > x_arg< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > reshape< / span > < span class = "p" > (< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ])< / span >
< span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "n" > _layer_norm_bwd_dx_fused< / span > < span class = "p" > [(< / span > < span class = "n" > M< / span > < span class = "p" > ,)](< / span > < span class = "n" > dx< / span > < span class = "p" > ,< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > _dw< / span > < span class = "p" > ,< / span > < span class = "n" > _db< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w< / span > < span class = "p" > ,< / span > < span class = "n" > b< / span > < span class = "p" > ,< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "p" > ,< / span > < span class = "n" > locks< / span > < span class = "p" > ,< / span >
< span class = "n" > x_arg< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > eps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "n" > GROUP_SIZE_M< / span > < span class = "o" > =< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > ctx< / span > < span class = "o" > .< / span > < span class = "n" > num_warps< / span > < span class = "p" > )< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > [< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > ' BLOCK_SIZE_N' < / span > < span class = "p" > ])]< / span >
< span class = "c1" > # accumulate partial sums in separate kernel< / span >
< span class = "n" > _layer_norm_bwd_dwdb< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span > < span class = "n" > _dw< / span > < span class = "p" > ,< / span > < span class = "n" > _db< / span > < span class = "p" > ,< / span > < span class = "n" > dw< / span > < span class = "p" > ,< / span > < span class = "n" > db< / span > < span class = "p" > ,< / span > < span class = "n" > GROUP_SIZE_M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_M< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE_N< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > dx< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > dw< / span > < span class = "p" > ,< / span > < span class = "n" > db< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span >
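
# A host-side sketch of the launch heuristics in LayerNorm.forward above. This is a
# hypothetical helper for illustration only. BLOCK_SIZE is the next power of two that
# is at least N, capped at 64KB worth of elements. Rows longer than that cap are
# rejected because this version processes a whole row in one block. num_warps adds
# one warp per 256 elements of BLOCK_SIZE, clamped to the range [1, 8].
def _fwd_launch_params(N, element_size):
    max_fused_size = 65536 // element_size
    # same value as triton.next_power_of_2(N) for N >= 1, computed without triton
    block_size = min(max_fused_size, 2 ** (N - 1).bit_length())
    if N > block_size:
        raise RuntimeError("feature dim larger than 64KB is not supported")
    num_warps = min(max(block_size // 256, 1), 8)
    return block_size, num_warps
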
< span class = "n" > layer_norm< / span > < span class = "o" > =< / span > < span class = "n" > LayerNorm< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span >
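
# The GROUP_SIZE_M heuristic from LayerNorm.backward above, restated as a hypothetical
# helper for illustration. Smaller feature dimensions get more partial-sum buffers,
# which means more independent reduction streams. Each buffer holds N elements and
# uses a lock and a count, so the lock tensor allocated in backward has
# 2 * GROUP_SIZE_M entries.
def _bwd_group_size_m(N):
    if N > 8192:
        return 64
    if N > 4096:
        return 96
    if N > 1024:
        return 128
    return 256
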
< span class = "k" > def< / span > < span class = "nf" > test_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ):< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "n" > y_tri< / span > < span class = "o" > =< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "n" > y_ref< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass (triton)< / span >
< span class = "n" > y_tri< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_tri< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "kc" > None< / span >
< span class = "c1" > # backward pass (torch)< / span >
< span class = "n" > y_ref< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > dx_ref< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > _< / span > < span class = "o" > .< / span > < span class = "n" > grad< / span > < span class = "o" > .< / span > < span class = "n" > clone< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "ow" > in< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ]]< / span >
< span class = "c1" > # compare< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > y_tri< / span > < span class = "p" > ,< / span > < span class = "n" > y_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dx_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dx_ref< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > db_tri< / span > < span class = "p" > ,< / span > < span class = "n" > db_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > assert_almost_equal< / span > < span class = "p" > (< / span > < span class = "n" > dw_tri< / span > < span class = "p" > ,< / span > < span class = "n" > dw_ref< / span > < span class = "p" > ,< / span > < span class = "n" > decimal< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' N' < / span > < span class = "p" > ],< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 512< / span > < span class = "o" > *< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 32< / span > < span class = "p" > )],< / span >
< span class = "n" > line_arg< / span > < span class = "o" > =< / span > < span class = "s1" > ' provider' < / span > < span class = "p" > ,< / span >
< span class = "n" > line_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' apex' < / span > < span class = "p" > ],< / span >
< span class = "n" > line_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' Triton' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' Torch' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' Apex' < / span > < span class = "p" > ],< / span >
< span class = "n" > styles< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s1" > ' blue' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' green' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' orange' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > )],< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s1" > ' GB/s' < / span > < span class = "p" > ,< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "s1" > ' layer-norm-backward' < / span > < span class = "p" > ,< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' M' < / span > < span class = "p" > :< / span > < span class = "mi" > 4096< / span > < span class = "p" > ,< / span > < span class = "s1" > ' dtype' < / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float16< / span > < span class = "p" > ,< / span > < span class = "s1" > ' mode' < / span > < span class = "p" > :< / span > < span class = "s1" > ' backward' < / span > < span class = "p" > }< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > bench_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ,< / span > < span class = "n" > mode< / span > < span class = "o" > =< / span > < span class = "s1" > ' backward' < / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "o" > =< / span > < span class = "mf" > 1e-5< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ):< / span >
< span class = "c1" > # create data< / span >
< span class = "n" > x_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > )< / span >
< span class = "n" > w_shape< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > [< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ],< / span > < span class = "p" > )< / span >
< span class = "n" > weight< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > bias< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "o" > -< / span > < span class = "mf" > 2.3< / span > < span class = "o" > +< / span > < span class = "mf" > 0.5< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > x_shape< / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > dtype< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > dy< / span > < span class = "o" > =< / span > < span class = "mf" > .1< / span > < span class = "o" > *< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad_< / span > < span class = "p" > (< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< span class = "c1" > # utility functions< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > :< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > functional< / span > < span class = "o" > .< / span > < span class = "n" > layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > w_shape< / span > < span class = "p" > ,< / span > < span class = "n" > weight< / span > < span class = "p" > ,< / span > < span class = "n" > bias< / span > < span class = "p" > ,< / span > < span class = "n" > eps< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' apex' < / span > < span class = "p" > :< / span >
< span class = "kn" > import< / span > < span class = "nn" > apex< / span >
< span class = "n" > apex_layer_norm< / span > < span class = "o" > =< / span > < span class = "n" > apex< / span > < span class = "o" > .< / span > < span class = "n" > normalization< / span > < span class = "o" > .< / span > < span class = "n" > FusedLayerNorm< / span > < span class = "p" > (< / span > < span class = "n" > w_shape< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "p" > )< / span >
< span class = "n" > y_fwd< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > apex_layer_norm< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # forward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' forward' < / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ,< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "c1" > # backward pass< / span >
< span class = "k" > if< / span > < span class = "n" > mode< / span > < span class = "o" > ==< / span > < span class = "s1" > ' backward' < / span > < span class = "p" > :< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 3< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > y_fwd< / span > < span class = "p" > ()< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > backward< / span > < span class = "p" > (< / span > < span class = "n" > dy< / span > < span class = "p" > ,< / span > < span class = "n" > retain_graph< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ),< / span >
< span class = "n" > grad_to_none< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "p" > ],< / span > < span class = "n" > rep< / span > < span class = "o" > =< / span > < span class = "mi" > 500< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > max_ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > min_ms< / span > < span class = "p" > )< / span >
< span class = "n" > bench_layer_norm< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span > < span class = "n" > save_path< / span > < span class = "o" > =< / span > < span class = "s1" > ' .' < / span > < span class = "p" > ,< / span > < span class = "n" > print_data< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
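<p>The test compares Triton's <code>dx</code>, <code>dw</code> and <code>db</code> against PyTorch's autograd. To make the quantities being compared concrete, here is a minimal pure-Python sketch of the per-row layer-norm forward and backward formulas, checked against central finite differences. It is not the Triton kernel from this tutorial: the helper names (<code>layer_norm_row_fwd</code>, <code>layer_norm_row_bwd</code>, <code>numeric_grad</code>, <code>loss</code>) and the sample values are hypothetical, and it assumes the biased variance that <code>torch.nn.functional.layer_norm</code> uses. With <code>g = w * dy</code>, the input gradient is <code>dx = rstd * (g - mean(g) - xhat * mean(xhat * g))</code>.</p>

```python
import math


def layer_norm_row_fwd(x, w, b, eps=1e-5):
    """Forward pass for one row; also returns xhat and rstd, which backward reuses."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n  # biased variance
    rstd = 1.0 / math.sqrt(var + eps)
    xhat = [(xi - mean) * rstd for xi in x]
    y = [wi * xh + bi for wi, xh, bi in zip(w, xhat, b)]
    return y, xhat, rstd


def layer_norm_row_bwd(dy, xhat, w, rstd):
    """Backward pass for one row: dx plus this row's contributions to dw and db."""
    n = len(dy)
    wdy = [wi * gi for wi, gi in zip(w, dy)]          # dL/dxhat
    c1 = sum(xh * g for xh, g in zip(xhat, wdy)) / n  # mean(xhat * wdy)
    c2 = sum(wdy) / n                                 # mean(wdy)
    dx = [(g - (xh * c1 + c2)) * rstd for xh, g in zip(xhat, wdy)]
    dw = [gi * xh for gi, xh in zip(dy, xhat)]
    db = list(dy)
    return dx, dw, db


def numeric_grad(f, v, h=1e-6):
    """Central finite-difference gradient of a scalar function f at the list v."""
    grad = []
    for i in range(len(v)):
        vp, vm = list(v), list(v)
        vp[i] += h
        vm[i] -= h
        grad.append((f(vp) - f(vm)) / (2 * h))
    return grad


def loss(x, w, b, dy, eps=1e-5):
    """Scalar L = sum(dy * y), so the upstream gradient dL/dy is exactly dy."""
    y, _, _ = layer_norm_row_fwd(x, w, b, eps)
    return sum(yi * gi for yi, gi in zip(y, dy))


# Hypothetical sample row (N = 5)
x = [0.3, -1.2, 2.0, 0.7, -0.4]
w = [1.0, 0.5, -0.8, 1.5, 0.9]
b = [0.1, -0.2, 0.0, 0.3, 0.05]
dy = [0.2, -0.1, 0.4, -0.3, 0.05]

_, xhat, rstd = layer_norm_row_fwd(x, w, b)
dx, dw, db = layer_norm_row_bwd(dy, xhat, w, rstd)
dx_num = numeric_grad(lambda v: loss(v, w, b, dy), x)
dw_num = numeric_grad(lambda v: loss(x, v, b, dy), w)
db_num = numeric_grad(lambda v: loss(x, w, v, dy), b)

max_err = max(abs(a - c) for a, c in zip(dx + dw + db, dx_num + dw_num + db_num))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

<p>Because both <code>mean(g)</code> and the component along <code>xhat</code> are removed, the entries of each row's <code>dx</code> sum to zero. The per-row <code>dw</code> and <code>db</code> contributions must be summed over all M rows to obtain <code>weight.grad</code> and <code>bias.grad</code>, which is the reduction the test checks with <code>decimal=1</code>.</p>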
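<p>The <code>gbps</code> lambdas in <code>bench_layer_norm</code> turn a time in milliseconds into an effective bandwidth by counting only the M×N tensors: the forward pass reads <code>x</code> and writes <code>y</code> (2 tensors), and the backward pass reads <code>x</code> and <code>dy</code> and writes <code>dx</code> (3 tensors). Smaller tensors such as <code>weight</code>, <code>bias</code> and their gradients are not counted. A minimal standalone restatement of that arithmetic follows; the function name <code>layer_norm_gbps</code> and the 0.1 ms timing are hypothetical.</p>

```python
def layer_norm_gbps(numel, element_size, ms, mode):
    """Effective bandwidth in GB/s using the byte count from bench_layer_norm."""
    # forward: read x, write y; backward: read x and dy, write dx
    n_tensors = {"forward": 2, "backward": 3}[mode]
    total_bytes = n_tensors * numel * element_size
    # bytes / (ms * 1e-3 s) / 1e9 bytes per GB simplifies to bytes / ms * 1e-6
    return total_bytes / ms * 1e-6


# Benchmark shape: M = 4096 rows, N = 4096 columns, float16 (2 bytes per element)
M, N, element_size = 4096, 4096, 2
bwd = layer_norm_gbps(M * N, element_size, 0.1, "backward")  # hypothetical 0.1 ms
fwd = layer_norm_gbps(M * N, element_size, 0.1, "forward")
print(f"backward: {bwd:.1f} GB/s, forward: {fwd:.1f} GB/s")
```

<p>Since bandwidth is inversely proportional to time, the slowest run (<code>max_ms</code>) maps to the lowest bandwidth. This is why <code>bench_layer_norm</code> returns <code>gbps(ms), gbps(max_ms), gbps(min_ms)</code>: the second and third values are the lower and upper bandwidth bounds.</p>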
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 2 minutes 12.791 seconds)< / p >
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-05-layer-norm-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 05-layer-norm.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 05-layer-norm.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "../../python-api/triton.html" class = "btn btn-neutral float-right" title = "triton" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "04-low-memory-dropout.html" class = "btn btn-neutral float-left" title = "Low-Memory Dropout" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: v1.1.2
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "05-layer-norm.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "../../../master/index.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >