2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!DOCTYPE html> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< html  class = "writer-html5"  lang = "en"  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<head>
								  <!-- Document metadata: encoding first, then viewport and page title. -->
								  <meta charset="utf-8">
								  <meta name="viewport" content="width=device-width, initial-scale=1.0">
								  <title>Low-Memory Dropout — Triton documentation</title>

								  <!-- Stylesheets, each referenced exactly once. pygments.css is loaded
								       before theme.css so that theme rules win on equal specificity;
								       the gallery sheets and the project's custom.css follow and may
								       override both. -->
								  <link rel="stylesheet" href="../../_static/pygments.css">
								  <link rel="stylesheet" href="../../_static/css/theme.css">
								  <link rel="stylesheet" href="../../_static/gallery.css">
								  <link rel="stylesheet" href="../../_static/gallery-binder.css">
								  <link rel="stylesheet" href="../../_static/gallery-dataframe.css">
								  <link rel="stylesheet" href="../../_static/gallery-rendered-html.css">
								  <link rel="stylesheet" href="../../_static/css/custom.css">

								  <!-- Conditional comment: only legacy IE (below version 9) evaluates it;
								       every other browser treats the whole span as an ordinary comment. -->
								  <!--[if lt IE 9]>
								    <script src="../../_static/js/html5shiv.min.js"></script>
								  <![endif]-->

								  <!-- Scripts are kept in their original relative order and remain
								       parser-blocking (no defer), since later scripts may depend on the
								       earlier ones. documentation_options.js is included once, so the id
								       "documentation_options" is unique in the document. -->
								  <script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
								  <script src="../../_static/jquery.js"></script>
								  <script src="../../_static/underscore.js"></script>
								  <script src="../../_static/doctools.js"></script>
								  <!-- NOTE(review): third-party CDN script without an integrity hash;
								       consider pinning an exact version and adding SRI. -->
								  <script async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
								  <script src="../../_static/js/theme.js"></script>

								  <!-- Relational links: site index, search page, and neighbouring tutorials. -->
								  <link rel="index" title="Index" href="../../genindex.html">
								  <link rel="search" title="Search" href="../../search.html">
								  <link rel="next" title="Layer Normalization" href="05-layer-norm.html">
								  <link rel="prev" title="Matrix Multiplication" href="03-matrix-multiplication.html">
								</head>
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< body  class = "wy-body-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "wy-grid-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < nav  data-toggle = "wy-nav-shift"  class = "wy-nav-side" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "wy-side-scroll" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "wy-side-nav-search"  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < a  href = "../../index.html"  class = "icon icon-home" >  Triton
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  role = "search" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < form  id = "rtd-search-form"  class = "wy-form"  action = "../../search.html"  method = "get" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "text"  name = "q"  placeholder = "Search docs"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "hidden"  name = "check_keywords"  value = "yes"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "hidden"  name = "area"  value = "default"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / form > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "wy-menu wy-menu-vertical"  data-spy = "affix"  role = "navigation"  aria-label = "main navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              < p  class = "caption"  role = "heading" > < span  class = "caption-text" > Getting Started< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "current" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../installation.html" > Installation< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 current" > < a  class = "reference internal"  href = "index.html" > Tutorials< / a > < ul  class = "current" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "01-vector-add.html" > Vector Addition< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "02-fused-softmax.html" > Fused Softmax< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2 current" > < a  class = "current reference internal"  href = "#" > Low-Memory Dropout< / a > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#baseline" > Baseline< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#seeded-dropout" > Seeded dropout< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#exercises" > Exercises< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#references" > References< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "05-layer-norm.html" > Layer Normalization< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "caption"  role = "heading" > < span  class = "caption-text" > Python API< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.html" > triton< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.language.html" > triton.language< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "caption"  role = "heading" > < span  class = "caption-text" > Programming Guide< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < section  data-toggle = "wy-nav-shift"  class = "wy-nav-content-wrap" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < nav  class = "wy-nav-top"  aria-label = "top navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < i  data-toggle = "wy-nav-top"  class = "fa fa-bars" > < / i > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < a  href = "../../index.html" > Triton< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "wy-nav-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "rst-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  role = "navigation"  aria-label = "breadcrumbs navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < ul  class = "wy-breadcrumbs" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li > < a  href = "../../index.html"  class = "icon icon-home" > < / a >  » < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < li > < a  href = "index.html" > Tutorials< / a >  » < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li > Low-Memory Dropout< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li  class = "wy-breadcrumbs-aside" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < a  href = "../../_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt"  rel = "nofollow" >  View page source< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < hr / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < div  role = "main"  class = "document"  itemscope = "itemscope"  itemtype = "http://schema.org/Article" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           < div  itemprop = "articleBody" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "sphx-glr-download-link-note admonition note" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Note< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Click < a  class = "reference internal"  href = "#sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py" > < span  class = "std std-ref" > here< / span > < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								to download the full example code< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-example-title section"  id = "low-memory-dropout" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  id = "sphx-glr-getting-started-tutorials-04-low-memory-dropout-py" > < / span > < h1 > Low-Memory Dropout< a  class = "headerlink"  href = "#low-memory-dropout"  title = "Permalink to this headline" > ¶< / a > < / h1 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > In this tutorial, you will write a memory-efficient implementation of dropout whose state
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > The limitations of naive implementations of Dropout with PyTorch< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Parallel pseudo-random number generation in Triton< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "baseline" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Baseline< a  class = "headerlink"  href = "#baseline"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The < em > dropout< / em >  operator was first introduced in < a  class = "reference internal"  href = "#srivastava2014"  id = "id1" > < span > [SRIVASTAVA2014]< / span > < / a >  as a way to improve the performance
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of deep neural networks in the low-data regime (i.e. regularization).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								output has a probability < span  class = "math notranslate nohighlight" > \(p\)< / span >  of being changed to zero and otherwise it is copied from the input.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								This forces the network to perform well even when only < span  class = "math notranslate nohighlight" > \(1 - p\)< / span >  scalars from the input are available.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > At evaluation time we want to use the full power of the network so we set < span  class = "math notranslate nohighlight" > \(p=0\)< / span > . Naively this would
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								in the output softmax temperature). To prevent this we multiply the output by < span  class = "math notranslate nohighlight" > \(\frac{1}{1 - p}\)< / span > , which
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								keeps the norm consistent regardless of the dropout probability.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s first take a look at the baseline implementation.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > tabulate< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > torch< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton.language< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > tl< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > _dropout< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > x_ptr< / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # pointer to the input< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > x_keep_ptr< / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # pointer to a mask of 0s and 1s< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > output_ptr< / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # pointer to the output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # number of elements in the `x` tensor< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > p< / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # probability that an element of `x` is changed to zero< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > block_start< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offsets< / span >  < span  class = "o" > =< / span >  < span  class = "n" > block_start< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > mask< / span >  < span  class = "o" > =< / span >  < span  class = "n" > offsets< / span >  < span  class = "o" > &lt;< / span >  < span  class = "n" > n_elements< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Load data< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > x_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > x_keep< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > x_keep_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # The line below is the crucial part, described in the paragraph above!< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > x_keep< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span >  < span  class = "o" > /< / span >  < span  class = "p" > (< / span > < span  class = "mi" > 1< / span >  < span  class = "o" > -< / span >  < span  class = "n" > p< / span > < span  class = "p" > ),< / span >  < span  class = "mf" > 0.0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Write-back output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > output_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > ,< / span >  < span  class = "n" > output< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x_keep< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty_like< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > is_contiguous< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > n_elements< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > numel< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > grid< / span >  < span  class = "o" > =< / span >  < span  class = "k" > lambda< / span >  < span  class = "n" > meta< / span > < span  class = "p" > :< / span >  < span  class = "p" > (< / span > < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span >  < span  class = "n" > meta< / span > < span  class = "p" > [< / span > < span  class = "s1" > 'BLOCK_SIZE'< / span > < span  class = "p" > ]),)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > _dropout< / span > < span  class = "p" > [< / span > < span  class = "n" > grid< / span > < span  class = "p" > ](< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x_keep< / span > < span  class = "p" > ,< / span >  < span  class = "n" > output< / span > < span  class = "p" > ,< / span >  < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "o" > =< / span > < span  class = "mi" > 1024< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # Input tensor< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > (< / span > < span  class = "n" > size< / span > < span  class = "o" > =< / span > < span  class = "p" > (< / span > < span  class = "mi" > 10< / span > < span  class = "p" > ,))< / span > < span  class = "o" > .< / span > < span  class = "n" > cuda< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # Dropout mask< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > p< / span >  < span  class = "o" > =< / span >  < span  class = "mf" > 0.5< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > x_keep< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > rand< / span > < span  class = "p" > (< / span > < span  class = "n" > size< / span > < span  class = "o" > =< / span > < span  class = "p" > (< / span > < span  class = "mi" > 10< / span > < span  class = "p" > ,))< / span >  < span  class = "o" > &gt;< / span >  < span  class = "n" > p< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > int32< / span > < span  class = "p" > )< / span > < span  class = "o" > .< / span > < span  class = "n" > cuda< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > #< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x_keep< / span > < span  class = "o" > =< / span > < span  class = "n" > x_keep< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "o" > =< / span > < span  class = "n" > p< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "n" > tabulate< / span > < span  class = "o" > .< / span > < span  class = "n" > tabulate< / span > < span  class = "p" > ([< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " input" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " keep mask" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > x_keep< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " output" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > output< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > ]))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-script-out" > Out:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-script-out highlight-none notranslate" > < div  class = "highlight" > < pre > < span > < / span > ---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								input      1.541    -0.293429  -2.17879  0.568431  -1.08452  -1.3986   0.403347  0.838026  -0.719258  -0.403344
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								keep mask  1         1          0        1          0         1        1         0          0          0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								output     3.08199  -0.586858   0        1.13686    0        -2.79719  0.806694  0          0          0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "seeded-dropout" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Seeded dropout< a  class = "headerlink"  href = "#seeded-dropout"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								very tricky when using recompute/checkpointing (e.g. see all the notes about < cite > preserve_rng_state< / cite >  in
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< a  class = "reference external"  href = "https://pytorch.org/docs/1.9.0/checkpoint.html" > https://pytorch.org/docs/1.9.0/checkpoint.html< / a > ). In this tutorial we’ll describe an alternative implementation
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of persisting randomness across multiple invocations of the kernel.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Pseudorandom number generation in Triton is simple! In this tutorial we will use the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "code docutils literal notranslate" > < span  class = "pre" > triton.language.rand< / span > < / code >  function which generates a block of uniformly distributed < code  class = "code docutils literal notranslate" > < span  class = "pre" > float32< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								values in [0, 1), given a seed and a block of < code  class = "code docutils literal notranslate" > < span  class = "pre" > int32< / span > < / code >  offsets. But if you need it, Triton also provides
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								other < a  class = "reference internal"  href = "../../python-api/triton.language.html#random-number-generation" > < span  class = "std std-ref" > random number generation strategies< / span > < / a > .< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "admonition note" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Note< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Triton’s implementation of PRNG is based on the Philox algorithm (described in < a  class = "reference internal"  href = "#salmon2011"  id = "id2" > < span > [SALMON2011]< / span > < / a > ).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								< p > Let’s put it all together.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > _seeded_dropout< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > x_ptr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > output_ptr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > p< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > seed< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # compute memory offsets of elements handled by this instance< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > block_start< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offsets< / span >  < span  class = "o" > =< / span >  < span  class = "n" > block_start< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # load data from x< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > mask< / span >  < span  class = "o" > =< / span >  < span  class = "n" > offsets< / span >  < span  class = "o" > < < / span >  < span  class = "n" > n_elements< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > x_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # randomly prune it< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > random< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > rand< / span > < span  class = "p" > (< / span > < span  class = "n" > seed< / span > < span  class = "p" > ,< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > x_keep< / span >  < span  class = "o" > =< / span >  < span  class = "n" > random< / span >  < span  class = "o" > > < / span >  < span  class = "n" > p< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # write-back< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > x_keep< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span >  < span  class = "o" > /< / span >  < span  class = "p" > (< / span > < span  class = "mi" > 1< / span >  < span  class = "o" > -< / span >  < span  class = "n" > p< / span > < span  class = "p" > ),< / span >  < span  class = "mf" > 0.0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > output_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offsets< / span > < span  class = "p" > ,< / span >  < span  class = "n" > output< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > seeded_dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "p" > ,< / span >  < span  class = "n" > seed< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty_like< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > is_contiguous< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > n_elements< / span >  < span  class = "o" > =< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > numel< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > grid< / span >  < span  class = "o" > =< / span >  < span  class = "k" > lambda< / span >  < span  class = "n" > meta< / span > < span  class = "p" > :< / span >  < span  class = "p" > (< / span > < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span >  < span  class = "n" > meta< / span > < span  class = "p" > [< / span > < span  class = "s1" > ' BLOCK_SIZE' < / span > < span  class = "p" > ]),)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > _seeded_dropout< / span > < span  class = "p" > [< / span > < span  class = "n" > grid< / span > < span  class = "p" > ](< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > output< / span > < span  class = "p" > ,< / span >  < span  class = "n" > n_elements< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "p" > ,< / span >  < span  class = "n" > seed< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE< / span > < span  class = "o" > =< / span > < span  class = "mi" > 1024< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > x< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > (< / span > < span  class = "n" > size< / span > < span  class = "o" > =< / span > < span  class = "p" > (< / span > < span  class = "mi" > 10< / span > < span  class = "p" > ,))< / span > < span  class = "o" > .< / span > < span  class = "n" > cuda< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # Compare this to the baseline - dropout mask is never instantiated!< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > seeded_dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "o" > =< / span > < span  class = "mf" > 0.5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > seed< / span > < span  class = "o" > =< / span > < span  class = "mi" > 123< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > output2< / span >  < span  class = "o" > =< / span >  < span  class = "n" > seeded_dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "o" > =< / span > < span  class = "mf" > 0.5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > seed< / span > < span  class = "o" > =< / span > < span  class = "mi" > 123< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > output3< / span >  < span  class = "o" > =< / span >  < span  class = "n" > seeded_dropout< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "n" > p< / span > < span  class = "o" > =< / span > < span  class = "mf" > 0.5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > seed< / span > < span  class = "o" > =< / span > < span  class = "mi" > 512< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "n" > tabulate< / span > < span  class = "o" > .< / span > < span  class = "n" > tabulate< / span > < span  class = "p" > ([< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " input" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > x< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " output (seed = 123)" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > output< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " output (seed = 123)" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > output2< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > (),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > [< / span > < span  class = "s2" > " output (seed = 512)" < / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > output3< / span > < span  class = "o" > .< / span > < span  class = "n" > tolist< / span > < span  class = "p" > ()< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > ]))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-script-out" > Out:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-script-out highlight-none notranslate" > < div  class = "highlight" > < pre > < span > < / span > -------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								input                -0.952835  0.371721  0.408716  1.42142  0.149397  -0.67086  -0.214186  -0.431969  -0.707878  -0.106434
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								output (seed = 512)   0         0         0.817432  2.84284  0         -1.34172  -0.428372   0          0          0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								-------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								If you’d like to explore further applications of pseudorandomness in GPU programming, we encourage you
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								to explore the < cite > triton/language/random< / cite >  folder!< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "exercises" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Exercises< a  class = "headerlink"  href = "#exercises"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ol  class = "arabic simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Extend the kernel to operate over a matrix and use a vector of seeds - one per row.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Add support for striding.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ol > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "references" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > References< a  class = "headerlink"  href = "#references"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< dl  class = "citation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< dt  class = "label"  id = "salmon2011" > < span  class = "brackets" > < a  class = "fn-backref"  href = "#id2" > SALMON2011< / a > < / span > < / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< dd > < p > John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< dt  class = "label"  id = "srivastava2014" > < span  class = "brackets" > < a  class = "fn-backref"  href = "#id1" > SRIVASTAVA2014< / a > < / span > < / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< dd > < p > Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / dl > 
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong >  ( 0 minutes  0.325 seconds)< / p > 
							 
						 
					
						
							
								
									
										
										
										
											2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-footer class sphx-glr-footer-example docutils container"  id = "sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-python docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > < a  class = "reference download internal"  download = ""  href = "../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py" > < code  class = "xref download docutils literal notranslate" > < span  class = "pre" > Download< / span >  < span  class = "pre" > Python< / span >  < span  class = "pre" > source< / span >  < span  class = "pre" > code:< / span >  < span  class = "pre" > 04-low-memory-dropout.py< / span > < / code > < / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-jupyter docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > < a  class = "reference download internal"  download = ""  href = "../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb" > < code  class = "xref download docutils literal notranslate" > < span  class = "pre" > Download< / span >  < span  class = "pre" > Jupyter< / span >  < span  class = "pre" > notebook:< / span >  < span  class = "pre" > 04-low-memory-dropout.ipynb< / span > < / code > < / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-signature" > < a  class = "reference external"  href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "rst-footer-buttons"  role = "navigation"  aria-label = "footer navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < a  href = "05-layer-norm.html"  class = "btn btn-neutral float-right"  title = "Layer Normalization"  accesskey = "n"  rel = "next" > Next < span  class = "fa fa-arrow-circle-right"  aria-hidden = "true" > < / span > < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < a  href = "03-matrix-multiplication.html"  class = "btn btn-neutral float-left"  title = "Matrix Multiplication"  accesskey = "p"  rel = "prev" > < span  class = "fa fa-arrow-circle-left"  aria-hidden = "true" > < / span >  Previous< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < hr / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  role = "contentinfo" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        ©  Copyright 2020, Philippe Tillet.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    Built with < a  href = "https://www.sphinx-doc.org/" > Sphinx< / a >  using a
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < a  href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    provided by < a  href = "https://readthedocs.org" > Read the Docs< / a > . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "rst-versions"  data-toggle = "rst-versions"  role = "note"  aria-label = "versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "rst-current-version"  data-toggle = "rst-current-version" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "fa fa-book" >  Other Versions< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        v: master
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "fa fa-caret-down" > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "rst-other-versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dt > Tags< / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dd > < a  href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dt > Branches< / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dd > < a  href = "04-low-memory-dropout.html" > master< / a > < / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < script  type = "text/javascript" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      jQuery(function () {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          SphinxRtdTheme.Navigation.enable(true);
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      });
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / body > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / html >