2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								<!DOCTYPE html> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< html  class = "writer-html5"  lang = "en"  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< head > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < meta  charset = "utf-8"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < meta  name = "viewport"  content = "width=device-width, initial-scale=1.0"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < title > Matrix Multiplication —  Triton  documentation< / title > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/css/theme.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/pygments.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/pygments.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/css/theme.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/gallery.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/gallery-binder.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/gallery-dataframe.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/gallery-rendered-html.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < link  rel = "stylesheet"  href = "../../_static/css/custom.css"  type = "text/css"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <!-- [if lt IE 9]>
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  src = "../../_static/js/html5shiv.min.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  <![endif]--> 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < script  type = "text/javascript"  id = "documentation_options"  data-url_root = "../../"  src = "../../_static/documentation_options.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < script  data-url_root = "../../"  id = "documentation_options"  src = "../../_static/documentation_options.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < script  src = "../../_static/jquery.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < script  src = "../../_static/underscore.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < script  src = "../../_static/doctools.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < script  type = "text/javascript"  src = "../../_static/js/theme.js" > < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "index"  title = "Index"  href = "../../genindex.html"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "search"  title = "Search"  href = "../../search.html"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "next"  title = "Low-Memory Dropout"  href = "04-low-memory-dropout.html"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < link  rel = "prev"  title = "Fused Softmax"  href = "02-fused-softmax.html"  / >  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / head > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< body  class = "wy-body-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "wy-grid-for-nav" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < nav  data-toggle = "wy-nav-shift"  class = "wy-nav-side" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "wy-side-scroll" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "wy-side-nav-search"  > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < a  href = "../../index.html"  class = "icon icon-home" >  Triton
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  role = "search" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < form  id = "rtd-search-form"  class = "wy-form"  action = "../../search.html"  method = "get" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "text"  name = "q"  placeholder = "Search docs"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "hidden"  name = "check_keywords"  value = "yes"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < input  type = "hidden"  name = "area"  value = "default"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / form > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "wy-menu wy-menu-vertical"  data-spy = "affix"  role = "navigation"  aria-label = "main navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								              < p  class = "caption"  role = "heading" > < span  class = "caption-text" > Getting Started< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "current" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../installation.html" > Installation< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1 current" > < a  class = "reference internal"  href = "index.html" > Tutorials< / a > < ul  class = "current" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "01-vector-add.html" > Vector Addition< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "02-fused-softmax.html" > Fused Softmax< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2 current" > < a  class = "current reference internal"  href = "#" > Matrix Multiplication< / a > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#motivations" > Motivations< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#compute-kernel" > Compute Kernel< / a > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l4" > < a  class = "reference internal"  href = "#pointer-arithmetics" > Pointer Arithmetics< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l4" > < a  class = "reference internal"  href = "#l2-cache-optimizations" > L2 Cache Optimizations< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#final-result" > Final Result< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#unit-test" > Unit Test< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l3" > < a  class = "reference internal"  href = "#benchmark" > Benchmark< / a > < ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l4" > < a  class = "reference internal"  href = "#square-matrix-performance" > Square Matrix Performance< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l2" > < a  class = "reference internal"  href = "05-layer-norm.html" > Layer Normalization< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "caption"  role = "heading" > < span  class = "caption-text" > Python API< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.html" > triton< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.language.html" > triton.language< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "caption"  role = "heading" > < span  class = "caption-text" > Programming Guide< / span > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li  class = "toctree-l1" > < a  class = "reference internal"  href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < section  data-toggle = "wy-nav-shift"  class = "wy-nav-content-wrap" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < nav  class = "wy-nav-top"  aria-label = "top navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < i  data-toggle = "wy-nav-top"  class = "fa fa-bars" > < / i > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < a  href = "../../index.html" > Triton< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / nav > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < div  class = "wy-nav-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < div  class = "rst-content" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  role = "navigation"  aria-label = "breadcrumbs navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < ul  class = "wy-breadcrumbs" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li > < a  href = "../../index.html"  class = "icon icon-home" > < / a >  » < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < li > < a  href = "index.html" > Tutorials< / a >  » < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li > Matrix Multiplication< / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < li  class = "wy-breadcrumbs-aside" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < a  href = "../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt"  rel = "nofollow" >  View page source< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < hr / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < div  role = "main"  class = "document"  itemscope = "itemscope"  itemtype = "http://schema.org/Article" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           < div  itemprop = "articleBody" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  class = "sphx-glr-download-link-note admonition note" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "admonition-title" > Note< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Click < a  class = "reference internal"  href = "#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py" > < span  class = "std std-ref" > here< / span > < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								to download the full example code< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-example-title section"  id = "matrix-multiplication" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  id = "sphx-glr-getting-started-tutorials-03-matrix-multiplication-py" > < / span > < h1 > Matrix Multiplication< a  class = "headerlink"  href = "#matrix-multiplication"  title = "Permalink to this headline" > ¶< / a > < / h1 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								kernel that achieves performance on par with cuBLAS.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								You will specifically learn about:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< ul  class = "simple" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Block-level matrix multiplications< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Multi-dimensional pointer arithmetic< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Program re-ordering for improved L2 cache hit rate< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< li > < p > Automatic performance tuning< / p > < / li > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / ul > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "motivations" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Motivations< a  class = "headerlink"  href = "#motivations"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Matrix multiplications are a key building block of most modern high-performance computing systems.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								They are notoriously hard to optimize, hence their implementation is generally done by
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								Unfortunately, these libraries are often proprietary and cannot be easily customized
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								In this tutorial, you will learn how to implement efficient matrix multiplications by
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								yourself with Triton, in a way that is easy to customize and extend.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > Roughly speaking, the kernel that we will write will implement the following blocked
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								algorithm to multiply a (M, K) by a (K, N) matrix:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "c1" > # do in parallel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > for< / span >  < span  class = "n" > m< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < span  class = "c1" > # do in parallel< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < span  class = "k" > for< / span >  < span  class = "n" > n< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > acc< / span >  < span  class = "o" > =< / span >  < span  class = "n" > zeros< / span > < span  class = "p" > ((< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > for< / span >  < span  class = "n" > k< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < span  class = "n" > a< / span >  < span  class = "o" > =< / span >  < span  class = "n" > A< / span > < span  class = "p" > [< / span > < span  class = "n" > m< / span >  < span  class = "p" > :< / span >  < span  class = "n" > m< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > k< / span >  < span  class = "p" > :< / span >  < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > B< / span > < span  class = "p" > [< / span > < span  class = "n" > k< / span >  < span  class = "p" > :< / span >  < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > n< / span >  < span  class = "p" > :< / span >  < span  class = "n" > n< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < span  class = "n" > acc< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > dot< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > C< / span > < span  class = "p" > [< / span > < span  class = "n" > m< / span >  < span  class = "p" > :< / span >  < span  class = "n" > m< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > n< / span >  < span  class = "p" > :< / span >  < span  class = "n" > n< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ]< / span >  < span  class = "o" > =< / span >  < span  class = "n" > acc< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "compute-kernel" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Compute Kernel< a  class = "headerlink"  href = "#compute-kernel"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > The above algorithm is, actually, fairly straightforward to implement in Triton.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								The main difficulty comes from the computation of the memory locations at which blocks
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								of < code  class = "code docutils literal notranslate" > < span  class = "pre" > A< / span > < / code >  and < code  class = "code docutils literal notranslate" > < span  class = "pre" > B< / span > < / code >  must be read in the inner loop. For that, we need
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								multi-dimensional pointer arithmetic.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "pointer-arithmetics" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Pointer Arithmetics< a  class = "headerlink"  href = "#pointer-arithmetics"  title = "Permalink to this headline" > ¶< / a > < / h3 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > For a row-major 2D tensor < code  class = "code docutils literal notranslate" > < span  class = "pre" > X< / span > < / code > , the memory location of < code  class = "code docutils literal notranslate" > < span  class = "pre" > X[i,< / span >  < span  class = "pre" > j]< / span > < / code >  is given by
								< code  class = "code docutils literal notranslate" > < span  class = "pre" > & X[i,< / span >  < span  class = "pre" > j]< / span >  < span  class = "pre" > =< / span >  < span  class = "pre" > X< / span >  < span  class = "pre" > +< / span >  < span  class = "pre" > i*stride_xi< / span >  < span  class = "pre" > +< / span >  < span  class = "pre" > j*stride_xj< / span > < / code > .
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								Therefore, blocks of pointers for < code  class = "code docutils literal notranslate" > < span  class = "pre" > A[m< / span >  < span  class = "pre" > :< / span >  < span  class = "pre" > m+BLOCK_SIZE_M,< / span >  < span  class = "pre" > k:k+BLOCK_SIZE_K]< / span > < / code >  and
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< code  class = "code docutils literal notranslate" > < span  class = "pre" > B[k< / span >  < span  class = "pre" > :< / span >  < span  class = "pre" > k+BLOCK_SIZE_K,< / span >  < span  class = "pre" > n< / span >  < span  class = "pre" > :< / span >  < span  class = "pre" > n+BLOCK_SIZE_N]< / span > < / code >  can be defined in pseudo-code as:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "o" > & < / span > < span  class = "n" > A< / span > < span  class = "p" > [< / span > < span  class = "n" > m< / span >  < span  class = "p" > :< / span >  < span  class = "n" > m< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > k< / span > < span  class = "p" > :< / span > < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ]< / span >  < span  class = "o" > =< / span >   < span  class = "n" > a_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > m< / span >  < span  class = "p" > :< / span >  < span  class = "n" > m< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )[:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span > < span  class = "o" > *< / span > < span  class = "n" > A< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > k< / span >  < span  class = "p" > :< / span >  < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > )[< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > < span  class = "o" > *< / span > < span  class = "n" > A< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "o" > & < / span > < span  class = "n" > B< / span > < span  class = "p" > [< / span > < span  class = "n" > k< / span >  < span  class = "p" > :< / span >  < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > n< / span > < span  class = "p" > :< / span > < span  class = "n" > n< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ]< / span >  < span  class = "o" > =< / span >   < span  class = "n" > b_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > k< / span >  < span  class = "p" > :< / span >  < span  class = "n" > k< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > )[:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span > < span  class = "o" > *< / span > < span  class = "n" > B< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > n< / span >  < span  class = "p" > :< / span >  < span  class = "n" > n< / span > < span  class = "o" > +< / span > < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )[< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > < span  class = "o" > *< / span > < span  class = "n" > B< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > This means that pointers for blocks of A and B can be initialized (i.e., < code  class = "code docutils literal notranslate" > < span  class = "pre" > k=0< / span > < / code > ) in Triton as follows:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > offs_am< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_m< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > offs_bn< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_n< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > offs_k< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > a_ptrs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > a_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_am< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span > < span  class = "o" > *< / span > < span  class = "n" > stride_am< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs_k< / span >  < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > < span  class = "o" > *< / span > < span  class = "n" > stride_ak< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > b_ptrs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > b_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_k< / span >  < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span > < span  class = "o" > *< / span > < span  class = "n" > stride_bk< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs_bn< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > < span  class = "o" > *< / span > < span  class = "n" > stride_bn< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > And then updated in the inner loop as follows:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > a_ptrs< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > BLOCK_SIZE_K< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_ak< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > b_ptrs< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > BLOCK_SIZE_K< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_bk< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "l2-cache-optimizations" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > L2 Cache Optimizations< a  class = "headerlink"  href = "#l2-cache-optimizations"  title = "Permalink to this headline" > ¶< / a > < / h3 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > As mentioned above, each program instance computes a < code  class = "code docutils literal notranslate" > < span  class = "pre" > [BLOCK_SIZE_M,< / span >  < span  class = "pre" > BLOCK_SIZE_N]< / span > < / code > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								block of < code  class = "code docutils literal notranslate" > < span  class = "pre" > C< / span > < / code > .
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								It is important to remember that the order in which these blocks are computed does
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								matter, since it affects the L2 cache hit rate of our program, and unfortunately,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								a simple row-major ordering< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-Python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > );< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > grid_m< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > BLOCK_SIZE_M< / span >  < span  class = "o" > -< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > )< / span >  < span  class = "o" > //< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > grid_n< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > BLOCK_SIZE_N< / span >  < span  class = "o" > -< / span >  < span  class = "mi" > 1< / span > < span  class = "p" > )< / span >  < span  class = "o" > //< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > //< / span >  < span  class = "n" > grid_n< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > pid_n< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > %< / span >  < span  class = "n" > grid_n< / span > < span  class = "p" > ;< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > is just not going to cut it.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > One possible solution is to launch blocks in an order that promotes data reuse.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
										 
									
								 
							
							
								This can be done by ‘super-grouping’ blocks in groups of < code  class = "code docutils literal notranslate" > < span  class = "pre" > GROUP_SIZE_M< / span > < / code >  rows before
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								switching to the next column:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < div  class = "highlight-python notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "c1" > # program ID< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # number of program ids along the M axis< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > num_pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # number of program ids along the N axis< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > num_pid_n< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # number of programs in group< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > num_pid_in_group< / span >  < span  class = "o" > =< / span >  < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > *< / span >  < span  class = "n" > num_pid_n< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # id of the group this program is in< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > group_id< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > //< / span >  < span  class = "n" > num_pid_in_group< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # row-id of the first program in the group< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > first_pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > group_id< / span >  < span  class = "o" > *< / span >  < span  class = "n" > GROUP_SIZE_M< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # if `num_pid_m` isn' t divisible by `GROUP_SIZE_M`, the last group is smaller< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > group_size_m< / span >  < span  class = "o" > =< / span >  < span  class = "nb" > min< / span > < span  class = "p" > (< / span > < span  class = "n" > num_pid_m< / span >  < span  class = "o" > -< / span >  < span  class = "n" > first_pid_m< / span > < span  class = "p" > ,< / span >  < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # *within groups*, programs are ordered in a column-major order< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # row-id of the program in the *launch grid*< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > first_pid_m< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > pid< / span >  < span  class = "o" > %< / span >  < span  class = "n" > group_size_m< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # col-id of the program in the *launch grid*< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > pid_n< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > pid< / span >  < span  class = "o" > %< / span >  < span  class = "n" > num_pid_in_group< / span > < span  class = "p" > )< / span >  < span  class = "o" > //< / span >  < span  class = "n" > group_size_m< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								we can see that if we compute the output in row-major ordering, we need to load 90
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								ordering, we only need to load 54 blocks.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div > < img  alt = "../../_images/grouped_vs_row_major_ordering.png"  src = "../../_images/grouped_vs_row_major_ordering.png"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > < / blockquote > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > In practice, this can improve the performance of our matrix multiplication kernel by
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "final-result" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Final Result< a  class = "headerlink"  href = "#final-result"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "kn" > import< / span >  < span  class = "nn" > torch< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "kn" > import< / span >  < span  class = "nn" > triton.language< / span >  < span  class = "k" > as< / span >  < span  class = "nn" > tl< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # %< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # :code:`triton.jit`' ed functions can be auto-tuned by using the `triton.autotune`< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # decorator, which consumes:< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > #   - A list of :code:`triton.Config` objects that define different configurations of< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > #       meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > #   - An autotuning *key* whose change in values will trigger evaluation of all the< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > #       provided configs< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > autotune< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > configs< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 256< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 3< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 8< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 256< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 3< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 8< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 256< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 256< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 128< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 4< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 2< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > Config< / span > < span  class = "p" > ({< / span > < span  class = "s1" > ' BLOCK_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_N' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 64< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' BLOCK_SIZE_K' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 32< / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' GROUP_SIZE_M' < / span > < span  class = "p" > :< / span >  < span  class = "mi" > 8< / span > < span  class = "p" > },< / span >  < span  class = "n" > num_stages< / span > < span  class = "o" > =< / span > < span  class = "mi" > 5< / span > < span  class = "p" > ,< / span >  < span  class = "n" > num_warps< / span > < span  class = "o" > =< / span > < span  class = "mi" > 2< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > ],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > key< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > < span  class = "s1" > ' M' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' N' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' K' < / span > < span  class = "p" > ],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > matmul_kernel< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Pointers to matrices< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > a_ptr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b_ptr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > c_ptr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Matrix dimensions< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # The stride variables represent how much to increase the ptr by when moving by 1< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # element in a particular dimension. E.g. stride_am is how much to increase a_ptr< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # by to get the element one row down (A has M rows)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > stride_am< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stride_ak< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > stride_bk< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stride_bn< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > stride_cm< / span > < span  class = "p" > ,< / span >  < span  class = "n" > stride_cn< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Meta-parameters< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > ACTIVATION< / span > < span  class = "p" > :< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > constexpr< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "sd" > " " " Kernel for computing the matmul C = A x B.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "sd" >     A has shape (M, K), B has shape (K, N) and C has shape (M, N)< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "sd" >     " " " < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # -----------------------------------------------------------< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Map program ids `pid` to the block of C it should compute.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # This is done in a grouped ordering to promote L2 data reuse< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # See above `L2 Cache Optimizations` section for details< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > pid< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > program_id< / span > < span  class = "p" > (< / span > < span  class = "n" > axis< / span > < span  class = "o" > =< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > num_pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > num_pid_n< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > num_pid_in_group< / span >  < span  class = "o" > =< / span >  < span  class = "n" > GROUP_SIZE_M< / span >  < span  class = "o" > *< / span >  < span  class = "n" > num_pid_n< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > group_id< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid< / span >  < span  class = "o" > //< / span >  < span  class = "n" > num_pid_in_group< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > first_pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > group_id< / span >  < span  class = "o" > *< / span >  < span  class = "n" > GROUP_SIZE_M< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > group_size_m< / span >  < span  class = "o" > =< / span >  < span  class = "nb" > min< / span > < span  class = "p" > (< / span > < span  class = "n" > num_pid_m< / span >  < span  class = "o" > -< / span >  < span  class = "n" > first_pid_m< / span > < span  class = "p" > ,< / span >  < span  class = "n" > GROUP_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > pid_m< / span >  < span  class = "o" > =< / span >  < span  class = "n" > first_pid_m< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > pid< / span >  < span  class = "o" > %< / span >  < span  class = "n" > group_size_m< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > pid_n< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > pid< / span >  < span  class = "o" > %< / span >  < span  class = "n" > num_pid_in_group< / span > < span  class = "p" > )< / span >  < span  class = "o" > //< / span >  < span  class = "n" > group_size_m< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # ----------------------------------------------------------< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Create pointers for the first blocks of A and B.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # We will advance this pointer as we move in the K direction< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # and accumulate< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # see above `Pointer Arithmetics` section for details< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offs_am< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_m< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offs_bn< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_n< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offs_k< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > a_ptrs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > a_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_am< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_am< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs_k< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_ak< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > b_ptrs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > b_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_k< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_bk< / span >  < span  class = "o" > +< / span >  < span  class = "n" > offs_bn< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_bn< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # -----------------------------------------------------------< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Iterate to compute a block of the C matrix< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # of fp32 values for higher accuracy.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # `accumulator` will be converted back to fp16 after the loop< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > accumulator< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > zeros< / span > < span  class = "p" > ((< / span > < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float32< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > for< / span >  < span  class = "n" > k< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_K< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # Note that for simplicity, we don' t apply a mask here.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # This means that if K is not a multiple of BLOCK_SIZE_K,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # this will access out-of-bounds memory and produce an< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # error or (worse!) incorrect results.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > a< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > a_ptrs< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > load< / span > < span  class = "p" > (< / span > < span  class = "n" > b_ptrs< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # We accumulate along the K dimension< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > accumulator< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > dot< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # Advance the ptrs to the next K block< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > a_ptrs< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > BLOCK_SIZE_K< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_ak< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > b_ptrs< / span >  < span  class = "o" > +=< / span >  < span  class = "n" > BLOCK_SIZE_K< / span >  < span  class = "o" > *< / span >  < span  class = "n" > stride_bk< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # you can fuse arbitrary activation functions here< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # while the accumulator is still in FP32!< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > ACTIVATION< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > accumulator< / span >  < span  class = "o" > =< / span >  < span  class = "n" > ACTIVATION< / span > < span  class = "p" > (< / span > < span  class = "n" > accumulator< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > c< / span >  < span  class = "o" > =< / span >  < span  class = "n" > accumulator< / span > < span  class = "o" > .< / span > < span  class = "n" > to< / span > < span  class = "p" > (< / span > < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # -----------------------------------------------------------< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # Write back the block of the output matrix C< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offs_cm< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_m< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_M< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_M< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > offs_cn< / span >  < span  class = "o" > =< / span >  < span  class = "n" > pid_n< / span >  < span  class = "o" > *< / span >  < span  class = "n" > BLOCK_SIZE_N< / span >  < span  class = "o" > +< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > arange< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > BLOCK_SIZE_N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > c_ptrs< / span >  < span  class = "o" > =< / span >  < span  class = "n" > c_ptr< / span >  < span  class = "o" > +< / span >  < span  class = "n" > stride_cm< / span >  < span  class = "o" > *< / span >  < span  class = "n" > offs_cm< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > +< / span >  < span  class = "n" > stride_cn< / span >  < span  class = "o" > *< / span >  < span  class = "n" > offs_cn< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > c_mask< / span >  < span  class = "o" > =< / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_cm< / span > < span  class = "p" > [:,< / span >  < span  class = "kc" > None< / span > < span  class = "p" > ]< / span >  < span  class = "o" > < < / span >  < span  class = "n" > M< / span > < span  class = "p" > )< / span >  < span  class = "o" > & < / span >  < span  class = "p" > (< / span > < span  class = "n" > offs_cn< / span > < span  class = "p" > [< / span > < span  class = "kc" > None< / span > < span  class = "p" > ,< / span >  < span  class = "p" > :]< / span >  < span  class = "o" > < < / span >  < span  class = "n" > N< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > store< / span > < span  class = "p" > (< / span > < span  class = "n" > c_ptrs< / span > < span  class = "p" > ,< / span >  < span  class = "n" > c< / span > < span  class = "p" > ,< / span >  < span  class = "n" > mask< / span > < span  class = "o" > =< / span > < span  class = "n" > c_mask< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "c1" > # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > jit< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > leaky_relu< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > tl< / span > < span  class = "o" > .< / span > < span  class = "n" > where< / span > < span  class = "p" > (< / span > < span  class = "n" > x< / span >  < span  class = "o" > >=< / span >  < span  class = "mi" > 0< / span > < span  class = "p" > ,< / span >  < span  class = "n" > x< / span > < span  class = "p" > ,< / span >  < span  class = "mf" > 0.01< / span >  < span  class = "o" > *< / span >  < span  class = "n" > x< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We can now create a convenience wrapper function that only takes two input tensors
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								and (1) checks any shape constraints; (2) allocates the output; and (3) launches the above kernel.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "k" > def< / span >  < span  class = "nf" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > activation< / span > < span  class = "o" > =< / span > < span  class = "kc" > None< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # checks constraints< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ]< / span >  < span  class = "o" > ==< / span >  < span  class = "n" > b< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > < span  class = "p" > [< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ],< / span >  < span  class = "s2" > "incompatible dimensions"< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > is_contiguous< / span > < span  class = "p" > (),< / span >  < span  class = "s2" > "matrix A must be contiguous"< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "n" > b< / span > < span  class = "o" > .< / span > < span  class = "n" > is_contiguous< / span > < span  class = "p" > (),< / span >  < span  class = "s2" > "matrix B must be contiguous"< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span >  < span  class = "o" > =< / span >  < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span >  < span  class = "o" > =< / span >  < span  class = "n" > b< / span > < span  class = "o" > .< / span > < span  class = "n" > shape< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > assert< / span >  < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > K< / span >  < span  class = "o" > %< / span >  < span  class = "mi" > 32< / span >  < span  class = "o" > ==< / span >  < span  class = "mi" > 0< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > ),< / span >  < span  class = "s2" > "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # allocates output< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > c< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > empty< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > device< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > dtype< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "c1" > # 1D launch kernel where each block gets its own program.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > grid< / span >  < span  class = "o" > =< / span >  < span  class = "k" > lambda< / span >  < span  class = "n" > META< / span > < span  class = "p" > :< / span >  < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > META< / span > < span  class = "p" > [< / span > < span  class = "s1" > 'BLOCK_SIZE_M'< / span > < span  class = "p" > ])< / span >  < span  class = "o" > *< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > cdiv< / span > < span  class = "p" > (< / span > < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > META< / span > < span  class = "p" > [< / span > < span  class = "s1" > 'BLOCK_SIZE_N'< / span > < span  class = "p" > ]),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > matmul_kernel< / span > < span  class = "p" > [< / span > < span  class = "n" > grid< / span > < span  class = "p" > ](< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > c< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ),< / span >  < span  class = "n" > a< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > b< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ),< / span >  < span  class = "n" > b< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > c< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > ),< / span >  < span  class = "n" > c< / span > < span  class = "o" > .< / span > < span  class = "n" > stride< / span > < span  class = "p" > (< / span > < span  class = "mi" > 1< / span > < span  class = "p" > ),< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ACTIVATION< / span > < span  class = "o" > =< / span > < span  class = "n" > activation< / span > < span  class = "p" > ,< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > c< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "unit-test" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Unit Test< a  class = "headerlink"  href = "#unit-test"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > manual_seed< / span > < span  class = "p" > (< / span > < span  class = "mi" > 0< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > a< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > ((< / span > < span  class = "mi" > 512< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 512< / span > < span  class = "p" > ),< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > 'cuda'< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > ((< / span > < span  class = "mi" > 512< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 512< / span > < span  class = "p" > ),< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > 'cuda'< / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > triton_output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > activation< / span > < span  class = "o" > =< / span > < span  class = "kc" > None< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > torch_output< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > "triton_output=< / span > < span  class = "si" > {< / span > < span  class = "n" > triton_output< / span > < span  class = "si" > }< / span > < span  class = "s2" > "< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "sa" > f< / span > < span  class = "s2" > "torch_output=< / span > < span  class = "si" > {< / span > < span  class = "n" > torch_output< / span > < span  class = "si" > }< / span > < span  class = "s2" > "< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > if< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > allclose< / span > < span  class = "p" > (< / span > < span  class = "n" > triton_output< / span > < span  class = "p" > ,< / span >  < span  class = "n" > torch_output< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "s2" > "✅ Triton and Torch match"< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > else< / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "nb" > print< / span > < span  class = "p" > (< / span > < span  class = "s2" > "❌ Triton and Torch differ"< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-script-out" > Out:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-script-out highlight-none notranslate" > < div  class = "highlight" > < pre > < span > < / span > triton_output=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        ...,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								       device='cuda:0', dtype=torch.float16)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								torch_output=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        ...,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								       device=' cuda:0' , dtype=torch.float16)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								✅ Triton and Torch match
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "benchmark" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h2 > Benchmark< a  class = "headerlink"  href = "#benchmark"  title = "Permalink to this headline" > ¶< / a > < / h2 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "section"  id = "square-matrix-performance" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< h3 > Square Matrix Performance< a  class = "headerlink"  href = "#square-matrix-performance"  title = "Permalink to this headline" > ¶< / a > < / h3 > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script as you wish to benchmark any other matrix shape.< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "highlight-default notranslate" > < div  class = "highlight" > < pre > < span > < / span > < span  class = "nd" > @triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > perf_report< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > Benchmark< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > x_names< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > < span  class = "s1" > ' M' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' N' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' K' < / span > < span  class = "p" > ],< / span >   < span  class = "c1" > # argument names to use as an x-axis for the plot< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > x_vals< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < span  class = "mi" > 128< / span >  < span  class = "o" > *< / span >  < span  class = "n" > i< / span >  < span  class = "k" > for< / span >  < span  class = "n" > i< / span >  < span  class = "ow" > in< / span >  < span  class = "nb" > range< / span > < span  class = "p" > (< / span > < span  class = "mi" > 2< / span > < span  class = "p" > ,< / span >  < span  class = "mi" > 33< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "p" > ],< / span >   < span  class = "c1" > # different possible values for `x_name`< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > line_arg< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' provider' < / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # argument name whose value corresponds to a different line in the plot< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # possible values for `line_arg`< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > line_vals< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > < span  class = "s1" > ' cublas' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' cublas + relu' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' triton' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' triton + relu' < / span > < span  class = "p" > ],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # label name for the lines< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > line_names< / span > < span  class = "o" > =< / span > < span  class = "p" > [< / span > < span  class = "s2" > " cuBLAS" < / span > < span  class = "p" > ,< / span >  < span  class = "s2" > " cuBLAS (+ torch.nn.LeakyReLU)" < / span > < span  class = "p" > ,< / span >  < span  class = "s2" > " Triton" < / span > < span  class = "p" > ,< / span >  < span  class = "s2" > " Triton (+ LeakyReLU)" < / span > < span  class = "p" > ],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "c1" > # line styles< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > styles< / span > < span  class = "o" > =< / span > < span  class = "p" > [(< / span > < span  class = "s1" > ' green' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' -' < / span > < span  class = "p" > ),< / span >  < span  class = "p" > (< / span > < span  class = "s1" > ' green' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' --' < / span > < span  class = "p" > ),< / span >  < span  class = "p" > (< / span > < span  class = "s1" > ' blue' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' -' < / span > < span  class = "p" > ),< / span >  < span  class = "p" > (< / span > < span  class = "s1" > ' blue' < / span > < span  class = "p" > ,< / span >  < span  class = "s1" > ' --' < / span > < span  class = "p" > )],< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ylabel< / span > < span  class = "o" > =< / span > < span  class = "s2" > " TFLOPS" < / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # label name for the y-axis< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > plot_name< / span > < span  class = "o" > =< / span > < span  class = "s2" > " matmul-performance" < / span > < span  class = "p" > ,< / span >   < span  class = "c1" > # name for the plot. Used also as a file name for saving the plot.< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > args< / span > < span  class = "o" > =< / span > < span  class = "p" > {},< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "k" > def< / span >  < span  class = "nf" > benchmark< / span > < span  class = "p" > (< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > provider< / span > < span  class = "p" > ):< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > a< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > ((< / span > < span  class = "n" > M< / span > < span  class = "p" > ,< / span >  < span  class = "n" > K< / span > < span  class = "p" > ),< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > b< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > randn< / span > < span  class = "p" > ((< / span > < span  class = "n" > K< / span > < span  class = "p" > ,< / span >  < span  class = "n" > N< / span > < span  class = "p" > ),< / span >  < span  class = "n" > device< / span > < span  class = "o" > =< / span > < span  class = "s1" > ' cuda' < / span > < span  class = "p" > ,< / span >  < span  class = "n" > dtype< / span > < span  class = "o" > =< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > float16< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > provider< / span >  < span  class = "o" > ==< / span >  < span  class = "s1" > ' cublas' < / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > min_ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > max_ms< / span >  < span  class = "o" > =< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > do_bench< / span > < span  class = "p" > (< / span > < span  class = "k" > lambda< / span > < span  class = "p" > :< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > provider< / span >  < span  class = "o" > ==< / span >  < span  class = "s1" > ' triton' < / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > min_ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > max_ms< / span >  < span  class = "o" > =< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > do_bench< / span > < span  class = "p" > (< / span > < span  class = "k" > lambda< / span > < span  class = "p" > :< / span >  < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > provider< / span >  < span  class = "o" > ==< / span >  < span  class = "s1" > ' cublas + relu' < / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > torch_relu< / span >  < span  class = "o" > =< / span >  < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > nn< / span > < span  class = "o" > .< / span > < span  class = "n" > ReLU< / span > < span  class = "p" > (< / span > < span  class = "n" > inplace< / span > < span  class = "o" > =< / span > < span  class = "kc" > True< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > min_ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > max_ms< / span >  < span  class = "o" > =< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > do_bench< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < span  class = "k" > lambda< / span > < span  class = "p" > :< / span >  < span  class = "n" > torch_relu< / span > < span  class = "p" > (< / span > < span  class = "n" > torch< / span > < span  class = "o" > .< / span > < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ))< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > if< / span >  < span  class = "n" > provider< / span >  < span  class = "o" > ==< / span >  < span  class = "s1" > ' triton + relu' < / span > < span  class = "p" > :< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "n" > ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > min_ms< / span > < span  class = "p" > ,< / span >  < span  class = "n" > max_ms< / span >  < span  class = "o" > =< / span >  < span  class = "n" > triton< / span > < span  class = "o" > .< / span > < span  class = "n" > testing< / span > < span  class = "o" > .< / span > < span  class = "n" > do_bench< / span > < span  class = "p" > (< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < span  class = "k" > lambda< / span > < span  class = "p" > :< / span >  < span  class = "n" > matmul< / span > < span  class = "p" > (< / span > < span  class = "n" > a< / span > < span  class = "p" > ,< / span >  < span  class = "n" > b< / span > < span  class = "p" > ,< / span >  < span  class = "n" > activation< / span > < span  class = "o" > =< / span > < span  class = "n" > leaky_relu< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "n" > perf< / span >  < span  class = "o" > =< / span >  < span  class = "k" > lambda< / span >  < span  class = "n" > ms< / span > < span  class = "p" > :< / span >  < span  class = "mi" > 2< / span >  < span  class = "o" > *< / span >  < span  class = "n" > M< / span >  < span  class = "o" > *< / span >  < span  class = "n" > N< / span >  < span  class = "o" > *< / span >  < span  class = "n" > K< / span >  < span  class = "o" > *< / span >  < span  class = "mf" > 1e-12< / span >  < span  class = "o" > /< / span >  < span  class = "p" > (< / span > < span  class = "n" > ms< / span >  < span  class = "o" > *< / span >  < span  class = "mf" > 1e-3< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "k" > return< / span >  < span  class = "n" > perf< / span > < span  class = "p" > (< / span > < span  class = "n" > ms< / span > < span  class = "p" > ),< / span >  < span  class = "n" > perf< / span > < span  class = "p" > (< / span > < span  class = "n" > max_ms< / span > < span  class = "p" > ),< / span >  < span  class = "n" > perf< / span > < span  class = "p" > (< / span > < span  class = "n" > min_ms< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< span  class = "n" > benchmark< / span > < span  class = "o" > .< / span > < span  class = "n" > run< / span > < span  class = "p" > (< / span > < span  class = "n" > show_plots< / span > < span  class = "o" > =< / span > < span  class = "kc" > True< / span > < span  class = "p" > ,< / span >  < span  class = "n" > print_data< / span > < span  class = "o" > =< / span > < span  class = "kc" > True< / span > < span  class = "p" > )< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< img  alt = "03 matrix multiplication"  class = "sphx-glr-single-img"  src = "../../_images/sphx_glr_03-matrix-multiplication_001.png"  / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-script-out" > Out:< / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-script-out highlight-none notranslate" > < div  class = "highlight" > < pre > < span > < / span > matmul-performance:
							 
						 
					
						
							
								
									
										
										
										
											2022-04-14 00:44:57 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								         M     cuBLAS  ...     Triton  Triton (+ LeakyReLU)
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								0    256.0   2.730667  ...   2.978909              2.978909
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								1    384.0   7.372800  ...   8.507077              8.507077
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								2    512.0  14.563555  ...  15.420235             16.384000
							 
						 
					
						
							
								
									
										
										
										
											2022-04-14 00:44:57 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								3    640.0  22.260869  ...  24.380953             24.380953
							 
						 
					
						
							
								
									
										
										
										
											2022-04-25 00:41:43 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								4    768.0  32.768000  ...  35.389441             34.028308
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								5    896.0  37.971025  ...  40.140799             39.025776
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								6   1024.0  49.932191  ...  53.773130             52.428801
							 
						 
					
						
							
								
									
										
										
										
											2022-04-24 00:44:07 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								7   1152.0  45.242181  ...  48.161033             47.396572
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								8   1280.0  51.200001  ...  57.690139             57.690139
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								9   1408.0  64.138541  ...  69.009825             67.305878
							 
						 
					
						
							
								
									
										
										
										
											2022-04-24 00:44:07 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								10  1536.0  79.526831  ...  79.526831             79.526831
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								11  1664.0  62.929456  ...  63.372618             62.929456
							 
						 
					
						
							
								
									
										
										
										
											2022-04-25 00:41:43 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								12  1792.0  72.983276  ...  63.499573             63.142831
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								13  1920.0  68.776119  ...  71.626943             71.257735
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								14  2048.0  73.584279  ...  78.398206             78.033565
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								15  2176.0  83.500614  ...  87.115360             86.739860
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								16  2304.0  68.251065  ...  77.810656             77.558029
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								17  2432.0  71.125224  ...  75.726318             75.522751
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								18  2560.0  77.833728  ...  82.331658             81.920002
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								19  2688.0  83.737433  ...  90.532356             90.316801
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								20  2816.0  82.290955  ...  83.712490             84.197315
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								21  2944.0  82.646820  ...  81.967162             83.477440
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								22  3072.0  82.062468  ...  85.662786             89.030036
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								23  3200.0  84.210524  ...  97.116842             95.952022
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								24  3328.0  83.905938  ...  86.946008             86.736504
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								25  3456.0  79.196043  ...  86.689860             91.407671
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								26  3584.0  87.211821  ...  94.947616             97.840469
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								27  3712.0  85.896254  ...  83.005689             88.404730
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								28  3840.0  81.738356  ...  88.297007             91.473945
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								29  3968.0  88.040360  ...  92.093539             84.797731
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								30  4096.0  93.336389  ...  91.491294             88.185107
							 
						 
					
						
							
								
									
										
										
										
											2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								[31 rows x 5 columns]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / pre > < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
									
										
										
										
											2022-04-26 00:43:32 +00:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong >  ( 6 minutes  27.164 seconds)< / p > 
							 
						 
					
						
							
								
									
										
										
										
											2022-02-09 07:15:50 +00:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-footer class sphx-glr-footer-example docutils container"  id = "sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-python docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > < a  class = "reference download internal"  download = ""  href = "../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py" > < code  class = "xref download docutils literal notranslate" > < span  class = "pre" > Download< / span >  < span  class = "pre" > Python< / span >  < span  class = "pre" > source< / span >  < span  class = "pre" > code:< / span >  < span  class = "pre" > 03-matrix-multiplication.py< / span > < / code > < / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "sphx-glr-download sphx-glr-download-jupyter docutils container" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p > < a  class = "reference download internal"  download = ""  href = "../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb" > < code  class = "xref download docutils literal notranslate" > < span  class = "pre" > Download< / span >  < span  class = "pre" > Jupyter< / span >  < span  class = "pre" > notebook:< / span >  < span  class = "pre" > 03-matrix-multiplication.ipynb< / span > < / code > < / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< p  class = "sphx-glr-signature" > < a  class = "reference external"  href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								           
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          < footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "rst-footer-buttons"  role = "navigation"  aria-label = "footer navigation" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < a  href = "04-low-memory-dropout.html"  class = "btn btn-neutral float-right"  title = "Low-Memory Dropout"  accesskey = "n"  rel = "next" > Next < span  class = "fa fa-arrow-circle-right"  aria-hidden = "true" > < / span > < / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < a  href = "02-fused-softmax.html"  class = "btn btn-neutral float-left"  title = "Fused Softmax"  accesskey = "p"  rel = "prev" > < span  class = "fa fa-arrow-circle-left"  aria-hidden = "true" > < / span >  Previous< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < hr / > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < div  role = "contentinfo" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        ©  Copyright 2020, Philippe Tillet.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / p > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    Built with < a  href = "https://www.sphinx-doc.org/" > Sphinx< / a >  using a
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < a  href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    provided by < a  href = "https://readthedocs.org" > Read the Docs< / a > . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / footer > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / section > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< div  class = "rst-versions"  data-toggle = "rst-versions"  role = "note"  aria-label = "versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < span  class = "rst-current-version"  data-toggle = "rst-current-version" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "fa fa-book" >  Other Versions< / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        v: master
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < span  class = "fa fa-caret-down" > < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / span > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < div  class = "rst-other-versions" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dt > Tags< / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dd > < a  href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dt > Branches< / dt > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								            < dd > < a  href = "03-matrix-multiplication.html" > master< / a > < / dd > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								        < / dl > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    < / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / div > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < script  type = "text/javascript" > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      jQuery(function () {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								          SphinxRtdTheme.Navigation.enable(true);
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								      });
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  < / script > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								   
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / body > 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
								
									
								 
							
							
								< / html >