684 lines
73 KiB
HTML
684 lines
73 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
|
||
<title>Matrix Multiplication — Triton documentation</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
|
||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
|
||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||
|
||
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
<link rel="next" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
|
||
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
|
||
|
||
<div class="wy-grid-for-nav">
|
||
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home"> Triton
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Matrix Multiplication</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#pointer-arithmetics">Pointer Arithmetics</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#l2-cache-optimizations">L2 Cache Optimizations</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#final-result">Final Result</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#square-matrix-performance">Square Matrix Performance</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||
|
||
|
||
<nav class="wy-nav-top" aria-label="top navigation">
|
||
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">Triton</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="wy-nav-content">
|
||
|
||
<div class="rst-content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||
|
||
<ul class="wy-breadcrumbs">
|
||
|
||
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a> »</li>
|
||
|
||
<li><a href="index.html">Tutorials</a> »</li>
|
||
|
||
<li>Matrix Multiplication</li>
|
||
|
||
|
||
<li class="wy-breadcrumbs-aside">
|
||
|
||
|
||
<a href="../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt" rel="nofollow"> View page source</a>
|
||
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<div class="sphx-glr-download-link-note admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p><a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Download the full example code</span></a></p>
|
||
</div>
|
||
<div class="sphx-glr-example-title section" id="matrix-multiplication">
|
||
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
|
||
kernel that achieves performance on par with cuBLAS.
|
||
You will specifically learn about:</p>
|
||
<ul class="simple">
|
||
<li><p>Block-level matrix multiplications</p></li>
|
||
<li><p>Multi-dimensional pointer arithmetic</p></li>
|
||
<li><p>Program re-ordering for improved L2 cache hit rate</p></li>
|
||
<li><p>Automatic performance tuning</p></li>
|
||
</ul>
|
||
<div class="section" id="motivations">
|
||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||
They are notoriously hard to optimize, hence their implementation is generally done by
|
||
hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
|
||
Unfortunately, these libraries are often proprietary and cannot be easily customized
|
||
to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||
yourself with Triton, in a way that is easy to customize and extend.</p>
|
||
<p>Roughly speaking, the kernel that we will write will implement the following blocked
|
||
algorithm to multiply an (M, K) matrix by a (K, N) matrix:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
|
||
<span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">):</span>
|
||
<span class="n">acc</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span>
|
||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="n">C</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.</p>
|
||
</div>
|
||
<div class="section" id="compute-kernel">
|
||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||
The main difficulty comes from the computation of the memory locations at which blocks
|
||
of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need
|
||
multi-dimensional pointer arithmetic.</p>
|
||
<div class="section" id="pointer-arithmetics">
|
||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by
|
||
<code class="code docutils literal notranslate"><span class="pre">&amp;X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_xi</span> <span class="pre">+</span> <span class="pre">j*stride_xj</span></code>.
|
||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_SIZE_M,</span> <span class="pre">k:k+BLOCK_SIZE_K]</span></code> and
|
||
<code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_SIZE_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_SIZE_N]</span></code> can be defined in pseudo-code as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>This means that pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
||
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_ak</span><span class="p">)</span>
|
||
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span> <span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_bn</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>And then updated in the inner loop as follows:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">;</span>
|
||
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
</div>
|
||
<div class="section" id="l2-cache-optimizations">
|
||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||
<p>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code>
|
||
block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||
It is important to remember that the order in which these blocks are computed does
|
||
matter, since it affects the L2 cache hit rate of our program, and unfortunately,
|
||
a simple row-major ordering</p>
|
||
<blockquote>
|
||
<div><div class="highlight-Python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_M</span><span class="p">;</span>
|
||
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_N</span><span class="p">;</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">%</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>is just not going to cut it.</p>
|
||
<p>One possible solution is to launch blocks in an order that promotes data reuse.
|
||
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_SIZE_M</span></code> rows before
|
||
switching to the next column:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># program ID</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="c1"># number of program ids along the M axis</span>
|
||
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||
<span class="c1"># number of program ids along the N axis</span>
|
||
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||
<span class="c1"># number of programs in group</span>
|
||
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
|
||
<span class="c1"># id of the group this program is in</span>
|
||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
|
||
<span class="c1"># row-id of the first program in the group</span>
|
||
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
||
<span class="c1"># if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller</span>
|
||
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
||
<span class="c1"># *within groups*, programs are ordered in a column-major order</span>
|
||
<span class="c1"># row-id of the program in the *launch grid*</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
||
<span class="c1"># col-id of the program in the *launch grid*</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||
we can see that if we compute the output in row-major ordering, we need to load 90
|
||
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||
ordering, we only need to load 54 blocks.</p>
|
||
<blockquote>
|
||
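<div><img alt="Comparison of row-major and grouped ordering for a matmul of 9-by-9-block matrices: computing the first 9 output blocks loads 90 input blocks in row-major order but only 54 in grouped order" src="../../_images/grouped_vs_row_major_ordering.png" />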
|
||
</div></blockquote>
|
||
<p>In practice, this can improve the performance of our matrix multiplication kernel by
|
||
more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).</p>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="final-result">
|
||
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline">¶</a></h2>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">triton</span>
|
||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||
|
||
<span class="c1"># %</span>
|
||
<span class="c1"># :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`</span>
|
||
<span class="c1"># decorator, which consumes:</span>
|
||
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of</span>
|
||
<span class="c1"># meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try</span>
|
||
<span class="c1"># - An autotuning *key* whose change in values will trigger evaluation of all the</span>
|
||
<span class="c1"># provided configs</span>
|
||
|
||
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
|
||
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="p">],</span>
|
||
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span>
|
||
<span class="p">)</span>
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
|
||
<span class="c1"># Pointers to matrices</span>
|
||
<span class="n">a_ptr</span><span class="p">,</span> <span class="n">b_ptr</span><span class="p">,</span> <span class="n">c_ptr</span><span class="p">,</span>
|
||
<span class="c1"># Matrix dimensions</span>
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
|
||
<span class="c1"># The stride variables represent how much to increase the ptr by when moving by 1</span>
|
||
<span class="c1"># element in a particular dimension. E.g. stride_am is how much to increase a_ptr</span>
|
||
<span class="c1"># by to get the element one row down (A has M rows)</span>
|
||
<span class="n">stride_am</span><span class="p">,</span> <span class="n">stride_ak</span><span class="p">,</span>
|
||
<span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span>
|
||
<span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span>
|
||
<span class="c1"># Meta-parameters</span>
|
||
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||
<span class="n">ACTIVATION</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="sd">"""Kernel for computing the matmul C = A x B.</span>
|
||
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
|
||
<span class="sd"> """</span>
|
||
<span class="c1"># -----------------------------------------------------------</span>
|
||
<span class="c1"># Map program ids `pid` to the block of C it should compute.</span>
|
||
<span class="c1"># This is done in a grouped ordering to promote L2 data reuse</span>
|
||
<span class="c1"># See above `L2 Cache Optimizations` section for details</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
|
||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
|
||
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
||
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
||
|
||
<span class="c1"># ----------------------------------------------------------</span>
|
||
<span class="c1"># Create pointers for the first blocks of A and B.</span>
|
||
<span class="c1"># We will advance this pointer as we move in the K direction</span>
|
||
<span class="c1"># and accumulate</span>
|
||
<span class="c1"># a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers</span>
|
||
<span class="c1"># b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers</span>
|
||
<span class="c1"># see above `Pointer Arithmetics` section for details</span>
|
||
<span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
||
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
|
||
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
|
||
|
||
<span class="c1"># -----------------------------------------------------------</span>
|
||
<span class="c1"># Iterate to compute a block of the C matrix</span>
|
||
<span class="c1"># We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block</span>
|
||
<span class="c1"># of fp32 values for higher accuracy.</span>
|
||
<span class="c1"># `accumulator` will be converted back to fp16 after the loop</span>
|
||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
|
||
<span class="c1"># Note that for simplicity, we don't apply a mask here.</span>
|
||
<span class="c1"># This means that if K is not a multiple of BLOCK_SIZE_K,</span>
|
||
<span class="c1"># this will access out-of-bounds memory and produce an</span>
|
||
<span class="c1"># error or (worse!) incorrect results.</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">a_ptrs</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">b_ptrs</span><span class="p">)</span>
|
||
<span class="c1"># We accumulate along the K dimension</span>
|
||
<span class="n">accumulator</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="c1"># Advance the ptrs to the next K block</span>
|
||
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||
<span class="c1"># you can fuse arbitrary activation functions here</span>
|
||
<span class="c1"># while the accumulator is still in FP32!</span>
|
||
<span class="k">if</span> <span class="n">ACTIVATION</span> <span class="o">==</span> <span class="s2">"leaky_relu"</span><span class="p">:</span>
|
||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">leaky_relu</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
|
||
<span class="n">c</span> <span class="o">=</span> <span class="n">accumulator</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
|
||
<span class="c1"># -----------------------------------------------------------</span>
|
||
<span class="c1"># Write back the block of the output matrix C</span>
|
||
<span class="n">offs_cm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||
<span class="n">offs_cn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||
<span class="n">c_ptrs</span> <span class="o">=</span> <span class="n">c_ptr</span> <span class="o">+</span> <span class="n">stride_cm</span> <span class="o">*</span> <span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">stride_cn</span> <span class="o">*</span> <span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||
<span class="n">c_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">c_ptrs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">c_mask</span><span class="p">)</span>
|
||
|
||
|
||
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`</span>
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now create a convenience wrapper function that only takes two input tensors
|
||
and (1) checks shape constraints; (2) allocates the output; (3) launches the above kernel.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">""</span><span class="p">):</span>
|
||
<span class="c1"># checks constraints</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">"incompatible dimensions"</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix A must be contiguous"</span>
|
||
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix B must be contiguous"</span>
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="n">K</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="k">assert</span> <span class="p">(</span>
|
||
<span class="n">K</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">==</span> <span class="mi">0</span>
|
||
<span class="p">),</span> <span class="s2">"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"</span>
|
||
<span class="c1"># allocates output</span>
|
||
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="c1"># 1D launch kernel where each block gets its own program.</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]),</span>
|
||
<span class="p">)</span>
|
||
<span class="n">matmul_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span>
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
|
||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">ACTIVATION</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">c</span>
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="unit-test">
|
||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||
<p>We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">triton_output</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="n">torch_output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"triton_output=</span><span class="si">{</span><span class="n">triton_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"torch_output=</span><span class="si">{</span><span class="n">torch_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">triton_output</span><span class="p">,</span> <span class="n">torch_output</span><span class="p">):</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"✅ Triton and Torch match"</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"❌ Triton and Torch differ"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||
...,
|
||
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||
...,
|
||
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
✅ Triton and Torch match
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="benchmark">
|
||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||
<div class="section" id="square-matrix-performance">
|
||
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline">¶</a></h3>
|
||
<p>We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script as you wish to benchmark any other matrix shape.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
|
||
<span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)</span>
|
||
<span class="p">],</span> <span class="c1"># different possible values for `x_names`</span>
|
||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||
<span class="c1"># possible values for `line_arg`</span>
|
||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'cublas + relu'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'triton + relu'</span><span class="p">],</span>
|
||
<span class="c1"># label name for the lines</span>
|
||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"cuBLAS (+ torch.nn.LeakyReLU)"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Triton (+ LeakyReLU)"</span><span class="p">],</span>
|
||
<span class="c1"># line styles</span>
|
||
        <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">)],</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span>  <span class="c1"># label name for the y-axis</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span>  <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{},</span>
    <span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas'</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas + relu'</span><span class="p">:</span>
        <span class="n">torch_relu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span>
            <span class="k">lambda</span><span class="p">:</span> <span class="n">torch_relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
        <span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton + relu'</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span>
            <span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"leaky_relu"</span><span class="p">)</span>
        <span class="p">)</span>
    <span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>


<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="matmul-performance plot: TFLOPS versus matrix size for cuBLAS and Triton, each with and without an activation applied to the product" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
         M     cuBLAS  ...     Triton  Triton (+ LeakyReLU)
0    256.0   2.978909  ...   3.276800              2.978909
1    384.0   7.372800  ...   8.507077              8.507077
2    512.0  14.563555  ...  16.384000             16.384000
3    640.0  22.260869  ...  24.380953             24.380953
4    768.0  32.768000  ...  35.389441             34.028308
5    896.0  39.025776  ...  40.140799             39.025776
6   1024.0  51.150050  ...  53.773130             52.428801
7   1152.0  45.242181  ...  47.396572             47.396572
8   1280.0  51.200001  ...  57.690139             57.690139
9   1408.0  64.138541  ...  68.147202             67.305878
10  1536.0  80.430545  ...  81.355034             79.526831
11  1664.0  62.929456  ...  63.372618             62.492442
12  1792.0  72.512412  ...  73.460287             59.467852
13  1920.0  69.120002  ...  71.626943             71.257735
14  2048.0  73.908442  ...  78.398206             77.314362
15  2176.0  83.500614  ...  87.115360             85.998493
16  2304.0  68.446623  ...  78.064941             77.307030
17  2432.0  71.305746  ...  85.393507             75.118889
18  2560.0  77.833728  ...  82.956960             81.715711
19  2688.0  83.922689  ...  90.748936             89.888756
20  2816.0  79.879498  ...  84.360174             83.873477
21  2944.0  82.373605  ...  83.337844             82.373605
22  3072.0  82.540970  ...  89.593522             88.750943
23  3200.0  84.544253  ...  96.603776             94.604578
24  3328.0  82.181847  ...  84.795401             84.397770
25  3456.0  81.435930  ...  91.928814             91.097818
26  3584.0  86.665439  ...  94.349836             94.548254
27  3712.0  85.822459  ...  86.641231             87.860458
28  3840.0  83.277102  ...  93.484358             85.730230
29  3968.0  92.512459  ...  80.917732             77.648067
30  4096.0  87.552332  ...  93.858555             89.299883

[31 rows x 5 columns]
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> (7 minutes 21.249 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">03-matrix-multiplication.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-right" title="Low-Memory Dropout" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
© Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="03-matrix-multiplication.html">master</a></dd>
</dl>
</div>
</div>
<script>
  // Runs once the DOM is ready: jQuery(fn) is shorthand for
  // jQuery(document).ready(fn). jquery.js and theme.js are both loaded in
  // <head>; SphinxRtdTheme is presumably provided by ../../_static/js/theme.js.
  // This call switches on the sphinx_rtd_theme sidebar navigation behaviour.
  // NOTE(review): `true` is presumably the theme's sticky-navigation flag;
  // confirm against theme.js before changing it.
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
</body>
</html>