<!-- 2022-02-09 07:15:50 +00:00 -->
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<!DOCTYPE html> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< html  class = "writer-html5"  lang = "en"  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<head>
								  <!-- Document metadata: encoding, responsive viewport, page title. -->
								  <meta charset="utf-8" />
								  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
								  <title>Layer Normalization — Triton documentation</title>

								  <!-- Stylesheets, each linked exactly once. pygments.css comes before
								       theme.css so that theme rules win ties in the cascade; the gallery
								       sheets and the site-specific custom.css follow and override both. -->
								  <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
								  <link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />

								  <!-- Conditional comment: only Internet Explorer 8 and older evaluate
								       the enclosed script tag; every other browser sees a plain comment. -->
								  <!--[if lt IE 9]>
								    <script src="../../_static/js/html5shiv.min.js"></script>
								  <![endif]-->

								  <!-- Scripts are loaded synchronously in this order. documentation_options.js
								       is included once; its id must stay unique in the document. -->
								  <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
								  <script src="../../_static/jquery.js"></script>
								  <script src="../../_static/underscore.js"></script>
								  <script src="../../_static/doctools.js"></script>
								  <script type="text/javascript" src="../../_static/js/theme.js"></script>

								  <!-- Relational links: index, search, and the neighbouring pages. -->
								  <link rel="index" title="Index" href="../../genindex.html" />
								  <link rel="search" title="Search" href="../../search.html" />
								  <link rel="next" title="triton" href="../../python-api/triton.html" />
								  <link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
								</head>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< body  class = "wy-body-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < div  class = "wy-grid-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <!-- Side navigation: site title link, search form, and the table of contents. -->
								    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
								      <div class="wy-side-scroll">
								        <div class="wy-side-nav-search">
								          <a href="../../index.html" class="icon icon-home"> Triton</a>

								          <div role="search">
								            <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
								              <!-- The placeholder is not an accessible name, so the field is
								                   labelled explicitly for assistive technology. -->
								              <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
								              <input type="hidden" name="check_keywords" value="yes" />
								              <input type="hidden" name="area" value="default" />
								            </form>
								          </div>
								        </div>

								        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
								          <p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
								          <ul class="current">
								            <li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
								            <li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a>
								              <ul class="current">
								                <li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
								                <li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
								                <li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
								                <li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
								                <!-- Entry for the page currently being viewed. -->
								                <li class="toctree-l2 current"><a class="current reference internal" href="#">Layer Normalization</a></li>
								              </ul>
								            </li>
								          </ul>

								          <p class="caption" role="heading"><span class="caption-text">Python API</span></p>
								          <ul>
								            <li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
								            <li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
								            <li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
								          </ul>

								          <p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
								          <ul>
								            <li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
								            <li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
								          </ul>
								        </div>
								      </div>
								    </nav>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < section  data-toggle = "wy-nav-shift"  class = "wy-nav-content-wrap" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      <!-- Top navigation bar: a toggle icon followed by a link to the site root. -->
								      <nav class="wy-nav-top" aria-label="top navigation">
								        <!-- NOTE(review): the data-toggle attribute is presumably the hook the
								             theme script uses to open the side menu; confirm before
								             changing this element's tag or attributes. -->
								        <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
								        <a href="../../index.html">Triton</a>
								      </nav>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      < div  class = "wy-nav-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < div  class = "rst-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<!-- Breadcrumb trail: home » Tutorials » current page, plus a link to the page source. -->
								<div role="navigation" aria-label="breadcrumbs navigation">
								  <ul class="wy-breadcrumbs">
								    <!-- The home link shows only an icon, so it carries an explicit accessible name. -->
								    <li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a> »</li>
								    <li><a href="index.html">Tutorials</a> »</li>
								    <li>Layer Normalization</li>
								    <li class="wy-breadcrumbs-aside">
								      <a href="../../_sources/getting-started/tutorials/05-layer-norm.rst.txt" rel="nofollow"> View page source</a>
								    </li>
								  </ul>
								  <hr/>
								</div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          < div  role = "main"  class = "document"  itemscope = "itemscope"  itemtype = "http://schema.org/Article" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								           < div  itemprop = "articleBody" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  <!-- Admonition pointing to the in-page anchor for the example's download section. -->
								  <div class="sphx-glr-download-link-note admonition note">
								    <p class="admonition-title">Note</p>
								    <p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">here</span></a>
								    to download the full example code</p>
								  </div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "sphx-glr-example-title section"  id = "layer-normalization" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  id = "sphx-glr-getting-started-tutorials-05-layer-norm-py" > < / span > < h1 > Layer Normalization< a  class = "headerlink"  href = "#layer-normalization"  title = "Permalink to this headline" > ¶< / a > < / h1 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< img  alt = "05 layer norm"  class = "sphx-glr-single-img"  src = "../../_images/sphx_glr_05-layer-norm_001.png"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< p  class = "sphx-glr-script-out" > Out:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "sphx-glr-script-out highlight-none notranslate" > < div  class = "highlight" > < pre > < span > < / span > layer-norm-backward:
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          N      Triton       Torch        Apex
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								0    1024.0  307.200008   98.303995  307.200008
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								1    1536.0  347.773587  134.050910  341.333333
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								2    2048.0  420.102553  161.154101  334.367350
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								3    2560.0  455.111129  181.238943  330.322572
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								4    3072.0  511.999982  191.999993  320.556515
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								5    3584.0  551.384634  207.768111  310.527060
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								6    4096.0  568.231237  220.412561  298.796351
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								7    4608.0  504.986315  232.825259  286.507772
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								8    5120.0  527.381977  242.845844  284.444444
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								9    5632.0  542.843364  243.545956  289.438969
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								10   6144.0  546.133354  248.661056  286.879370
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								11   6656.0  532.479975  256.000009  285.767438
							 
						 
					
						
							
								
									
										
										
										
											2022-02-14 00:38:35 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								12   7168.0  507.469040  260.654538  286.242939
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								13   7680.0  479.999983  262.190612  278.850215
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								14   8192.0  462.607053  267.130429  284.939124
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								15   8704.0  417.791980  267.815384  284.987724
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								16   9216.0  430.319054  272.394084  288.751954
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								17   9728.0  438.857162  280.278512  290.027323
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								18  10240.0  449.287041  286.433562  290.153487
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								19  10752.0  426.525614  247.172406  290.594591
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								20  11264.0  426.397479  245.536784  286.676558
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								21  11776.0  422.457417  249.667843  288.686414
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								22  12288.0  419.504980  254.673582  294.029924
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								23  12800.0  413.458944  253.465340  289.538159
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								24  13312.0  411.181478  252.559690  289.916513
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								25  13824.0  404.112047  256.991469  292.313649
							 
						 
					
						
							
								
									
										
										
										
											2022-02-16 00:38:53 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								26  14336.0  393.215988  254.485198  286.719986
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								27  14848.0  385.245405  257.665934  289.246765
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								28  15360.0  373.874218  257.970599  287.326580
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								29  15872.0  371.274849  261.806182  289.899545
							 
						 
					
						
							
								
									
										
										
										
											2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "line-block" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "line" > < br  / > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > torch< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton.language< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > tl< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > try< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # should not be added to extras_require in setup.py.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "kn" > import< / span >  < span  class = "nn" > apex< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > HAS_APEX< / span >  < span  class = "o" > =< / span >  < span  class = "kc" > True< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > except< / span >  < span  class = "ne" > ModuleNotFoundError< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > HAS_APEX< / span >  < span  class = "o" > =< / span >  < span  class = "kc" > False< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "c1" > # Forward Pass< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > _layer_norm_fwd_fused< / span > < span  class = "p" > (< / span > < span  class = "n" > X< / span > < span  class = "p" > ,< / span >  < span  class = "n" > Y< / span > < span  class = "p" > ,< / span >  < span  class = "n" > W< / span > < span  class = "p" > ,< / span >  < span  class = "n" > B< / span > < span  class = "p" > ,< / span >  < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > V< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stride< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > eps< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                          < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # position of elements processed by this program< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > row< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > cols< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mask< / span >  < span  class = "o" > =< / span >  < span  class = "n" > cols< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # offset data pointers to start at the row of interest< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > X< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > row< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > Y< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > row< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # load data and cast to float32< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > X< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > other< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # compute mean< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mean< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > /< / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # compute std< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > xmean< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span >  < span  class = "o" > -< / span >  < span  class = "n" > mean< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 0.< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > var< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > xmean< / span >  < span  class = "o" > *< / span >  < span  class = "n" > xmean< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > /< / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > rstd< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 1< / span >  < span  class = "o" > /< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sqrt< / span > < span  class = "p" > (< / span > < span  class = "n" > var< / span >  < span  class = "o" > +< / span >  < span  class = "n" > eps< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > xhat< / span >  < span  class = "o" > =< / span >  < span  class = "n" > xmean< / span >  < span  class = "o" > *< / span >  < span  class = "n" > rstd< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # write-back mean/rstd< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > row< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mean< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > V< / span >  < span  class = "o" > +< / span >  < span  class = "n" > row< / span > < span  class = "p" > ,< / span >  < span  class = "n" > rstd< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # multiply by weight and add bias< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > w< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > W< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > B< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > y< / span >  < span  class = "o" > =< / span >  < span  class = "n" > xhat< / span >  < span  class = "o" > *< / span >  < span  class = "n" > w< / span >  < span  class = "o" > +< / span >  < span  class = "n" > b< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # write-back< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > Y< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "c1" > # Backward pass (DX + partial DW + partial DB)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > _layer_norm_bwd_dx_fused< / span > < span  class = "p" > (< / span > < span  class = "n" > DX< / span > < span  class = "p" > ,< / span >  < span  class = "n" > DY< / span > < span  class = "p" > ,< / span >  < span  class = "n" > DW< / span > < span  class = "p" > ,< / span >  < span  class = "n" > DB< / span > < span  class = "p" > ,< / span >  < span  class = "n" > X< / span > < span  class = "p" > ,< / span >  < span  class = "n" > W< / span > < span  class = "p" > ,< / span >  < span  class = "n" > B< / span > < span  class = "p" > ,< / span >  < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > V< / span > < span  class = "p" > ,< / span >  < span  class = "n" > Lock< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stride< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > eps< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                             < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # position of elements processed by this program< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > row< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > cols< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mask< / span >  < span  class = "o" > =< / span >  < span  class = "n" > cols< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # offset data pointers to start at the row of interest< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > X< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > row< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > DY< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > row< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > DX< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > row< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # offset locks and weight/bias gradient pointer< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # each kernel instance accumulates partial sums for< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # DW and DB into one of GROUP_SIZE_M independent buffers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # these buffers stay in the L2, which allow this kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # to be fast< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > lock_id< / span >  < span  class = "o" > =< / span >  < span  class = "n" > row< / span >  < span  class = "o" > %< / span >  < span  class = "n" > GROUP_SIZE_M< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > Lock< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > lock_id< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > Count< / span >  < span  class = "o" > =< / span >  < span  class = "n" > Lock< / span >  < span  class = "o" > +< / span >  < span  class = "n" > GROUP_SIZE_M< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > DW< / span >  < span  class = "o" > =< / span >  < span  class = "n" > DW< / span >  < span  class = "o" > +< / span >  < span  class = "n" > lock_id< / span >  < span  class = "o" > *< / span >  < span  class = "n" > N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > DB< / span >  < span  class = "o" > =< / span >  < span  class = "n" > DB< / span >  < span  class = "o" > +< / span >  < span  class = "n" > lock_id< / span >  < span  class = "o" > *< / span >  < span  class = "n" > N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # load data to SRAM< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > X< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > other< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > dy< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > DY< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > other< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > w< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > W< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mean< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > row< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > rstd< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > V< / span >  < span  class = "o" > +< / span >  < span  class = "n" > row< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # compute dx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > xhat< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > x< / span >  < span  class = "o" > -< / span >  < span  class = "n" > mean< / span > < span  class = "p" > )< / span >  < span  class = "o" > *< / span >  < span  class = "n" > rstd< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > wdy< / span >  < span  class = "o" > =< / span >  < span  class = "n" > w< / span >  < span  class = "o" > *< / span >  < span  class = "n" > dy< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > xhat< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > xhat< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 0.< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > wdy< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > wdy< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 0.< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mean1< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > xhat< / span >  < span  class = "o" > *< / span >  < span  class = "n" > wdy< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > /< / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > mean2< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > wdy< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > /< / span >  < span  class = "n" > N< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > dx< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > wdy< / span >  < span  class = "o" > -< / span >  < span  class = "p" > (< / span > < span  class = "n" > xhat< / span >  < span  class = "o" > *< / span >  < span  class = "n" > mean1< / span >  < span  class = "o" > +< / span >  < span  class = "n" > mean2< / span > < span  class = "p" > ))< / span >  < span  class = "o" > *< / span >  < span  class = "n" > rstd< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # write-back dx< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > DX< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dx< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # accumulate partial sums for dw/db< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > partial_dw< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > dy< / span >  < span  class = "o" > *< / span >  < span  class = "n" > xhat< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > partial_db< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > dy< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > while< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > atomic_cas< / span > < span  class = "p" > (< / span > < span  class = "n" > Lock< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > )< / span >  < span  class = "o" > ==< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > pass< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > count< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > Count< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # first store doesn't accumulate< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > count< / span >  < span  class = "o" > ==< / span >  < span  class = "mi" > 0< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > atomic_xchg< / span > < span  class = "p" > (< / span > < span  class = "n" > Count< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > else< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > partial_dw< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > DW< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > partial_db< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > DB< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > DW< / span > < span  class = "p" > ,< / span >  < span  class = "n" > partial_dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > DB< / span > < span  class = "p" > ,< / span >  < span  class = "n" > partial_db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # release lock< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > atomic_xchg< / span > < span  class = "p" > (< / span > < span  class = "n" > Lock< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "c1" > # Backward pass (total DW + total DB)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > _layer_norm_bwd_dwdb< / span > < span  class = "p" > (< / span > < span  class = "n" > DW< / span > < span  class = "p" > ,< / span >  < span  class = "n" > DB< / span > < span  class = "p" > ,< / span >  < span  class = "n" > FINAL_DW< / span > < span  class = "p" > ,< / span >  < span  class = "n" > FINAL_DB< / span > < span  class = "p" > ,< / span >  < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                         < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > cols< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > dw< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > zeros< / span > < span  class = "p" > ((< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > db< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > zeros< / span > < span  class = "p" > ((< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > for< / span >  < span  class = "n" > i< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > rows< / span >  < span  class = "o" > =< / span >  < span  class = "n" > i< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > mask< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > rows< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > < < / span >  < span  class = "n" > M< / span > < span  class = "p" > )< / span >  < span  class = "o" > & < / span >  < span  class = "p" > (< / span > < span  class = "n" > cols< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > offs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > rows< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > *< / span >  < span  class = "n" > N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > dw< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > DW< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > other< / span > < span  class = "o" > =< / span > < span  class = "mf" > 0.< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > db< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > DB< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > ,< / span >  < span  class = "n" > other< / span > < span  class = "o" > =< / span > < span  class = "mf" > 0.< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > sum_dw< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > sum_db< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > sum< / span > < span  class = "p" > (< / span > < span  class = "n" > db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > FINAL_DW< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > sum_dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > cols< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > FINAL_DB< / span >  < span  class = "o" > +< / span >  < span  class = "n" > cols< / span > < span  class = "p" > ,< / span >  < span  class = "n" > sum_db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > cols< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > class< / span >  < span  class = "nc" > LayerNorm< / span > < span  class = "p" > (< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > autograd< / span > < span  class = "o" > .< / span > < span  class = "n" > Function< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "nd" > @staticmethod< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > def< / span >  < span  class = "nf" > forward< / span > < span  class = "p" > (< / span > < span  class = "n" > ctx< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > normalized_shape< / span > < span  class = "p" > ,< / span >  < span  class = "n" > weight< / span > < span  class = "p" > ,< / span >  < span  class = "n" > bias< / span > < span  class = "p" > ,< / span >  < span  class = "n" > eps< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # allocate output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > y< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty_like< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # reshape input data into 2D tensor< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > x_arg< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > reshape< / span > < span  class = "p" > (< / span > < span  class = "o" > -< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "o" > -< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ])< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x_arg< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > mean< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > rstd< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # Less than 64KB per feature: enqueue fused kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > MAX_FUSED_SIZE< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 65536< / span >  < span  class = "o" > //< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > element_size< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > BLOCK_SIZE< / span >  < span  class = "o" > =< / span >  < span  class = "nb" > min< / span > < span  class = "p" > (< / span > < span  class = "n" > MAX_FUSED_SIZE< / span > < span  class = "p" > ,< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > next_power_of_2< / span > < span  class = "p" > (< / span > < span  class = "n" > N< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > if< / span >  < span  class = "n" > N< / span >  < span  class = "o" > > < / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            < span  class = "k" > raise< / span >  < span  class = "ne" > RuntimeError< / span > < span  class = "p" > (< / span > < span  class = "s2" > " This layer norm doesn' t support feature dim > = 64KB." < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # heuristics for number of warps< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > num_warps< / span >  < span  class = "o" > =< / span >  < span  class = "nb" > min< / span > < span  class = "p" > (< / span > < span  class = "nb" > max< / span > < span  class = "p" > (< / span > < span  class = "n" > BLOCK_SIZE< / span >  < span  class = "o" > //< / span >  < span  class = "mi" > 256< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > ),< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # enqueue kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > _layer_norm_fwd_fused< / span > < span  class = "p" > [(< / span > < span  class = "n" > M< / span > < span  class = "p" > ,)](< / span > < span  class = "n" > x_arg< / span > < span  class = "p" > ,< / span >  < span  class = "n" > y< / span > < span  class = "p" > ,< / span >  < span  class = "n" > weight< / span > < span  class = "p" > ,< / span >  < span  class = "n" > bias< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mean< / span > < span  class = "p" > ,< / span >  < span  class = "n" > rstd< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                    < span  class = "n" > x_arg< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ),< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > eps< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                    < span  class = "n" > BLOCK_SIZE< / span > < span  class = "o" > =< / span > < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "n" > num_warps< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > save_for_backward< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > weight< / span > < span  class = "p" > ,< / span >  < span  class = "n" > bias< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mean< / span > < span  class = "p" > ,< / span >  < span  class = "n" > rstd< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > BLOCK_SIZE< / span >  < span  class = "o" > =< / span >  < span  class = "n" > BLOCK_SIZE< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > num_warps< / span >  < span  class = "o" > =< / span >  < span  class = "n" > num_warps< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > eps< / span >  < span  class = "o" > =< / span >  < span  class = "n" > eps< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > return< / span >  < span  class = "n" > y< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "nd" > @staticmethod< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "k" > def< / span >  < span  class = "nf" > backward< / span > < span  class = "p" > (< / span > < span  class = "n" > ctx< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dy< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > w< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > m< / span > < span  class = "p" > ,< / span >  < span  class = "n" > v< / span >  < span  class = "o" > =< / span >  < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > saved_tensors< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # heuristics for amount of parallel reduction stream for DG/DB< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > N< / span >  < span  class = "o" > =< / span >  < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 64< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > if< / span >  < span  class = "n" > N< / span >  < span  class = "o" > < =< / span >  < span  class = "mi" > 8192< / span > < span  class = "p" > :< / span >  < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 96< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > if< / span >  < span  class = "n" > N< / span >  < span  class = "o" > < =< / span >  < span  class = "mi" > 4096< / span > < span  class = "p" > :< / span >  < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 128< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > if< / span >  < span  class = "n" > N< / span >  < span  class = "o" > < =< / span >  < span  class = "mi" > 1024< / span > < span  class = "p" > :< / span >  < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > =< / span >  < span  class = "mi" > 256< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # allocate output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > locks< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > zeros< / span > < span  class = "p" > (< / span > < span  class = "mi" > 2< / span >  < span  class = "o" > *< / span >  < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > int32< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > _dw< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ]),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > _db< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ]),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > dw< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ],),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > db< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ],),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "n" > w< / span > < span  class = "o" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > dx< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty_like< / span > < span  class = "p" > (< / span > < span  class = "n" > dy< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # enqueue kernel using forward pass heuristics< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # also compute partial sums for DW and DB< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > x_arg< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > reshape< / span > < span  class = "p" > (< / span > < span  class = "o" > -< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "o" > -< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ])< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x_arg< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > _layer_norm_bwd_dx_fused< / span > < span  class = "p" > [(< / span > < span  class = "n" > M< / span > < span  class = "p" > ,)](< / span > < span  class = "n" > dx< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dy< / span > < span  class = "p" > ,< / span >  < span  class = "n" > _dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > _db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > w< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > m< / span > < span  class = "p" > ,< / span >  < span  class = "n" > v< / span > < span  class = "p" > ,< / span >  < span  class = "n" > locks< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                       < span  class = "n" > x_arg< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ),< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > eps< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                       < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "o" > =< / span > < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                       < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "o" > =< / span > < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                       < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "n" > ctx< / span > < span  class = "o" > .< / span > < span  class = "n" > num_warps< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > grid< / span >  < span  class = "o" > =< / span >  < span  class = "k" > lambda< / span >  < span  class = "n" > meta< / span > < span  class = "p" > :< / span >  < span  class = "p" > [< / span > < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > meta< / span > < span  class = "p" > [< / span > < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > ])]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "c1" > # accumulate partial sums in separate kernel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "n" > _layer_norm_bwd_dwdb< / span > < span  class = "p" > [< / span > < span  class = "n" > grid< / span > < span  class = "p" > ](< / span > < span  class = "n" > _dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > _db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > db< / span > < span  class = "p" > ,< / span >  < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                   < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "o" > =< / span > < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                   < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "o" > =< / span > < span  class = "mi" > 128< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < span  class = "k" > return< / span >  < span  class = "n" > dx< / span > < span  class = "p" > ,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dw< / span > < span  class = "p" > ,< / span >  < span  class = "n" > db< / span > < span  class = "p" > ,< / span >  < span  class = "kc" > None< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "n" > layer_norm< / span >  < span  class = "o" > =< / span >  < span  class = "n" > LayerNorm< / span > < span  class = "o" > .< / span > < span  class = "n" > apply< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > test_layer_norm< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > eps< / span > < span  class = "o" > =< / span > < span  class = "mf" > 1e-5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "c1" > # create data< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > x_shape< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > w_shape< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > x_shape< / span > < span  class = "p" > [< / span > < span  class = "o" > -< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ],< / span >  < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > weight< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > rand< / span > < span  class = "p" > (< / span > < span  class = "n" > w_shape< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > ,< / span >  < span  class = "n" > requires_grad< / span > < span  class = "o" > =< / span > < span  class = "kc" > True< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > bias< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > rand< / span > < span  class = "p" > (< / span > < span  class = "n" > w_shape< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > ,< / span >  < span  class = "n" > requires_grad< / span > < span  class = "o" > =< / span > < span  class = "kc" > True< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "o" > -< / span > < span  class = "mf" > 2.3< / span >  < span  class = "o" > +< / span >  < span  class = "mf" > 0.5< / span >  < span  class = "o" > *< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > (< / span > < span  class = "n" > x_shape< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > dtype< / span > < span  class = "p" > ,< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > dy< / span >  < span  class = "o" > =< / span >  < span  class = "mf" > .1< / span >  < span  class = "o" > *< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn_like< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > requires_grad_< / span > < span  class = "p" > (< / span > < span  class = "kc" > True< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># forward pass</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">y_tri</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># backward pass (triton)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">y_tri</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">dx_tri</span><span class="p">,</span> <span class="n">dw_tri</span><span class="p">,</span> <span class="n">db_tri</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">x</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># backward pass (torch)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">y_ref</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">dx_ref</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">db_ref</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># compare</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dx_tri</span><span class="p">,</span> <span class="n">dx_ref</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">db_tri</span><span class="p">,</span> <span class="n">db_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dw_tri</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N'</span><span class="p">],</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">)],</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'apex'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'Triton'</span><span class="p">,</span> <span class="s1">'Torch'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'Apex'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'orange'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">)],</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">ylabel</span><span class="o">=</span><span class="s1">'GB/s'</span><span class="p">,</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">plot_name</span><span class="o">=</span><span class="s1">'layer-norm-backward'</span><span class="p">,</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">,</span> <span class="s1">'dtype'</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">'mode'</span><span class="p">:</span> <span class="s1">'backward'</span><span class="p">}</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<span class="k">def</span> <span class="nf">bench_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'backward'</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">):</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># create data</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># utility functions</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'apex'</span><span class="p">:</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">apex_layer_norm</span> <span class="o">=</span> <span class="n">apex</span><span class="o">.</span><span class="n">normalization</span><span class="o">.</span><span class="n">FusedLayerNorm</span><span class="p">(</span><span class="n">w_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">apex_layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># forward pass</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'forward'</span><span class="p">:</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">y_fwd</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="c1"># backward pass</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'backward'</span><span class="p">:</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">y</span> <span class="o">=</span> <span class="n">y_fwd</span><span class="p">()</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								                                                     <span class="n">grad_to_none</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">'.'</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								</pre></div>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								</div>
							 
						 
					
						
							
								
									
										
										
										
											2022-02-17 00:40:30 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes  11.405 seconds)</p>
							 
						 
					
						
							
								
									
										
										
										
											2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "sphx-glr-footer class sphx-glr-footer-example docutils container"  id = "sphx-glr-download-getting-started-tutorials-05-layer-norm-py" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-python docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-jupyter docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<p><a class="reference download internal" download="" href="../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">05-layer-norm.ipynb</span></code></a></p>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								           < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								           
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          < footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < div  class = "rst-footer-buttons"  role = "navigation"  aria-label = "footer navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  <hr/>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < div  role = "contentinfo" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        ©  Copyright 2020, Philippe Tillet.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    provided by <a href="https://readthedocs.org">Read the Docs</a>.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< div  class = "rst-versions"  data-toggle = "rst-versions"  role = "note"  aria-label = "versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < span  class = "rst-current-version"  data-toggle = "rst-current-version" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="fa fa-book"> Other Versions</span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        v: master
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        <span class="fa fa-caret-down"></span>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < div  class = "rst-other-versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            <dt>Tags</dt>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            <dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            <dt>Branches</dt>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								            <dd><a href="05-layer-norm.html">master</a></dd>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < script  type = "text/javascript" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      jQuery(function () {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								          SphinxRtdTheme.Navigation.enable(true);
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								      });
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / body > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								< / html >