639 lines
69 KiB
HTML
639 lines
69 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
|
||
<title>Matrix Multiplication — Triton documentation</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
|
||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
|
||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||
|
||
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
<link rel="next" title="Introduction" href="../../programming-guide/chapter-1/introduction.html" />
|
||
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
|
||
|
||
<div class="wy-grid-for-nav">
|
||
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home"> Triton
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Matrix Multiplication</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#pointer-arithmetics">Pointer Arithmetics</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#l2-cache-optimizations">L2 Cache Optimizations</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#final-result">Final Result</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#auto-tuning">Auto-Tuning</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#autograd-function">Autograd Function</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a><ul>
|
||
<li class="toctree-l4"><a class="reference internal" href="#installing-the-cutlass-bindings">Installing The CUTLASS Bindings</a></li>
|
||
<li class="toctree-l4"><a class="reference internal" href="#square-matrix-performance">Square Matrix Performance</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||
|
||
|
||
<nav class="wy-nav-top" aria-label="top navigation">
|
||
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">Triton</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="wy-nav-content">
|
||
|
||
<div class="rst-content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||
|
||
<ul class="wy-breadcrumbs">
|
||
|
||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||
|
||
<li><a href="index.html">Tutorials</a> »</li>
|
||
|
||
<li>Matrix Multiplication</li>
|
||
|
||
|
||
<li class="wy-breadcrumbs-aside">
|
||
|
||
|
||
<a href="../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt" rel="nofollow"> View page source</a>
|
||
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<div class="sphx-glr-download-link-note admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">here</span></a>
|
||
to download the full example code</p>
|
||
</div>
|
||
<div class="sphx-glr-example-title section" id="matrix-multiplication">
|
||
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS’s performance.
|
||
You will specifically learn about:</p>
|
||
<ul class="simple">
|
||
<li><p>The block-level matrix multiplication operator <code class="code docutils literal notranslate"><span class="pre">@</span></code></p></li>
|
||
<li><p>Multi-dimensional pointer arithmetic</p></li>
|
||
<li><p>Program re-ordering for improved L2 cache hit rate</p></li>
|
||
<li><p>Automatic performance tuning</p></li>
|
||
</ul>
|
||
<div class="section" id="motivations">
|
||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
|
||
Unfortunately, these libraries are often proprietary and cannot be customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.</p>
|
||
<p>Roughly speaking, the kernel that we will write will implement the following blocked algorithm:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">MB</span><span class="p">):</span>
|
||
<span class="c1"># do in parallel</span>
|
||
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">NB</span><span class="p">):</span>
|
||
<span class="n">acc</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">KB</span><span class="p">):</span>
|
||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">]</span> <span class="o">@</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span>
|
||
<span class="n">C</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.</p>
|
||
</div>
|
||
<div class="section" id="compute-kernel">
|
||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||
<p>The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the <code class="code docutils literal notranslate"><span class="pre">@</span></code> operator for block-level matrix multiplication.
|
||
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> that we need to read in the inner loop.</p>
|
||
<div class="section" id="pointer-arithmetics">
|
||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&amp;X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*X.stride(0)</span> <span class="pre">+</span> <span class="pre">j</span></code>.
|
||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+MB,</span> <span class="pre">k:k+KB]</span></code> and <code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+KB,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+NB]</span></code> can be defined in pseudo-code as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">)[:,</span> <span class="n">newaxis</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">)[</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:];</span>
|
||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">)[:,</span> <span class="n">newaxis</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">)[</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:];</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>This means that, at initialization (i.e., <code class="code docutils literal notranslate"><span class="pre">k</span> <span class="pre">=</span> <span class="pre">0</span></code>), pointers for blocks of A and B can be set up in Triton as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">rm</span><span class="p">[</span><span class="n">MB</span><span class="p">]</span> <span class="o">=</span> <span class="n">program_id_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">rn</span><span class="p">[</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">program_id_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">rk</span><span class="p">[</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">KB</span><span class="p">;</span>
|
||
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pa</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span> <span class="o">+</span> <span class="n">rk</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
|
||
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pb</span><span class="p">[</span><span class="n">KB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>These pointers can then be updated in the inner loop as:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="mi">1</span><span class="p">;</span>
|
||
<span class="n">pb</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
</div>
|
||
<div class="section" id="l2-cache-optimizations">
|
||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||
<p>As mentioned above, each program instance computes an <code class="code docutils literal notranslate"><span class="pre">[MB,</span> <span class="pre">NB]</span></code> block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||
However, the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program.
|
||
For example, a naive row-major ordering:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">program_id</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">MB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">MB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">NB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">NB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">program_id_m</span> <span class="o">=</span> <span class="n">program_id</span> <span class="o">/</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">program_id_n</span> <span class="o">=</span> <span class="n">program_id</span> <span class="o">%</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>is unlikely to result in optimal performance.</p>
|
||
<p>One possible solution is to launch blocks in an order that promotes data reuse.
|
||
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_SIZE</span></code> before switching to the next column:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_SIZE</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">/</span> <span class="n">width</span><span class="p">;</span>
|
||
<span class="c1">// we need to handle the case where M % (GROUP_SIZE*MB) != 0</span>
|
||
<span class="kt">int</span> <span class="n">group_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE</span><span class="p">,</span> <span class="n">GROUP_SIZE</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).</p>
|
||
</div>
|
||
<div class="section" id="final-result">
|
||
<h3>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline">¶</a></h3>
|
||
<p>We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
|
||
Note that we rematerialize <code class="code docutils literal notranslate"><span class="pre">rm</span></code> and <code class="code docutils literal notranslate"><span class="pre">rn</span></code> after the inner loop to decrease register pressure.
|
||
This is an optimization that provides an additional 5% performance improvement and cannot currently be done automatically by the Triton compiler.</p>
|
||
<blockquote>
|
||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="cp">#define MAX_GROUP_SIZE 8</span>
|
||
|
||
<span class="n">__global__</span> <span class="kt">void</span> <span class="n">dot</span><span class="p">(</span><span class="n">TYPE</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="n">TYPE</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="n">TYPE</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span>
|
||
<span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="kt">int</span> <span class="n">K</span><span class="p">,</span>
|
||
<span class="kt">int</span> <span class="n">stride_a_0</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_b_0</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_c_0</span><span class="p">)</span> <span class="p">{</span>
|
||
<span class="c1">// prologue</span>
|
||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">MB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">MB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">NB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">NB</span><span class="p">;</span>
|
||
<span class="c1">// re-order program ID for better L2 performance</span>
|
||
<span class="kt">int</span> <span class="n">width</span> <span class="o">=</span> <span class="n">MAX_GROUP_SIZE</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">/</span> <span class="n">width</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">group_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">MAX_GROUP_SIZE</span><span class="p">,</span> <span class="n">MAX_GROUP_SIZE</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">MAX_GROUP_SIZE</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
|
||
<span class="kt">int</span> <span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
|
||
<span class="c1">// pointers to operands</span>
|
||
<span class="c1">// note the parentheses here; they force the offset</span>
|
||
<span class="c1">// computation to happen in typeof(stride_a_0) = int32 rather than</span>
|
||
<span class="c1">// typeof(A) = int64</span>
|
||
<span class="kt">int</span> <span class="n">rm</span><span class="p">[</span><span class="n">MB</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">rn</span><span class="p">[</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
|
||
<span class="kt">int</span> <span class="n">rk</span><span class="p">[</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">KB</span><span class="p">;</span>
|
||
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pa</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span><span class="p">);</span>
|
||
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pb</span><span class="p">[</span><span class="n">KB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
|
||
<span class="c1">// reduction loop</span>
|
||
<span class="kt">float</span> <span class="n">acc</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
|
||
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">K</span><span class="p">;</span> <span class="n">k</span> <span class="o">></span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">-=</span> <span class="n">KB</span><span class="p">)</span> <span class="p">{</span>
|
||
<span class="n">acc</span> <span class="o">+=</span> <span class="p">(</span><span class="o">*</span><span class="n">pa</span><span class="p">)</span> <span class="err">@</span> <span class="p">(</span><span class="o">*</span><span class="n">pb</span><span class="p">);</span>
|
||
<span class="n">pa</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="mi">1</span><span class="p">;</span>
|
||
<span class="n">pb</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
|
||
<span class="p">}</span>
|
||
<span class="c1">// pointers to output</span>
|
||
<span class="c1">// here we rematerialize `rm` and `rn` so that they are not live through</span>
|
||
<span class="c1">// the above reduction loop. In the future, the compiler should be able to</span>
|
||
<span class="c1">// do this automatically.</span>
|
||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
|
||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
|
||
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pc</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_c_0</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]);</span>
|
||
<span class="c1">// we write back using the *?() operator. `acc` (float) gets cast to `TYPE` implicitly.</span>
|
||
<span class="o">*?</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span> <span class="o">&&</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span> <span class="n">pc</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>Where <code class="code docutils literal notranslate"><span class="pre">TYPE</span></code> is the data-type of the input matrices and <code class="code docutils literal notranslate"><span class="pre">MB</span></code>, <code class="code docutils literal notranslate"><span class="pre">NB</span></code>, <code class="code docutils literal notranslate"><span class="pre">KB</span></code> are the block sizes defined in the above pseudo-code.
|
||
Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
|
||
If <code class="code docutils literal notranslate"><span class="pre">TYPE</span></code> is <code class="code docutils literal notranslate"><span class="pre">half</span></code>, then tensor cores will be used automatically provided that <code class="code docutils literal notranslate"><span class="pre">MB</span></code>, <code class="code docutils literal notranslate"><span class="pre">NB</span></code> and <code class="code docutils literal notranslate"><span class="pre">KB</span></code> are multiples of 16.</p>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="torch-bindings">
|
||
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||
<div class="section" id="auto-tuning">
|
||
<h3>Auto-Tuning<a class="headerlink" href="#auto-tuning" title="Permalink to this headline">¶</a></h3>
|
||
<p>In order to use Triton’s built-in auto-tuner in the above kernel, we need to define a list of <code class="code docutils literal notranslate"><span class="pre">triton.config</span></code> objects, which can be constructed as follows:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="kn">import</span> <span class="nn">triton</span>
|
||
|
||
<span class="n">autotune_configs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s2">"MB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span> <span class="s2">"NB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span> <span class="s2">"KB"</span><span class="p">:</span> <span class="s2">"32"</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We also need to define a list of strings (i.e., the “autotuning key”) that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
|
||
Here, we want to re-tune our kernel only when the shape of input matrices changes.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">autotune_key</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"M"</span><span class="p">,</span> <span class="s2">"N"</span><span class="p">,</span> <span class="s2">"K"</span><span class="p">]</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>We can now create an auto-tuned kernel by passing the <cite>autotune_configs</cite> and <cite>autotune_key</cite> lists to the constructor of the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||
<span class="s2">#define MAX_GROUP_SIZE 8</span>
|
||
|
||
<span class="s2">__global__ void dot(TYPE* A, TYPE* B, TYPE* C,</span>
|
||
<span class="s2"> int M, int N, int K,</span>
|
||
<span class="s2"> int lda, int ldb, int ldc) {</span>
|
||
<span class="s2"> int pid = get_program_id(0);</span>
|
||
<span class="s2"> int grid_m = (M + MB - 1) / MB;</span>
|
||
<span class="s2"> int grid_n = (N + NB - 1) / NB;</span>
|
||
<span class="s2"> int width = MAX_GROUP_SIZE * grid_n;</span>
|
||
<span class="s2"> int group_id = pid / width;</span>
|
||
<span class="s2"> int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);</span>
|
||
<span class="s2"> int pid_m = group_id * MAX_GROUP_SIZE + (pid </span><span class="si">% g</span><span class="s2">roup_size);</span>
|
||
<span class="s2"> int pid_n = (pid % width) / (group_size);</span>
|
||
<span class="s2"> int rm[MB] = pid_m * MB + 0 ... MB;</span>
|
||
<span class="s2"> int rn[NB] = pid_n * NB + 0 ... NB;</span>
|
||
<span class="s2"> int rk[KB] = 0 ... KB;</span>
|
||
<span class="s2"> TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);</span>
|
||
<span class="s2"> TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);</span>
|
||
<span class="s2"> float acc[MB, NB] = 0;</span>
|
||
<span class="s2"> for (int k = K; k > 0; k -= KB) {</span>
|
||
<span class="s2"> acc += (*pa) @ (*pb);</span>
|
||
<span class="s2"> pa += KB * 1;</span>
|
||
<span class="s2"> pb += KB * ldb;</span>
|
||
<span class="s2"> }</span>
|
||
<span class="s2"> rm = pid_m * MB + 0 ... MB;</span>
|
||
<span class="s2"> rn = pid_n * NB + 0 ... NB;</span>
|
||
<span class="s2"> TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);</span>
|
||
<span class="s2"> *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;</span>
|
||
<span class="s2">}</span>
|
||
<span class="s2">"""</span>
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span>
|
||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
|
||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'TYPE'</span><span class="p">:</span> <span class="n">dtype</span><span class="p">}</span>
|
||
<span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span> <span class="n">autotune_vals</span><span class="o">=</span><span class="n">autotune_configs</span><span class="p">,</span> <span class="n">autotune_key</span><span class="o">=</span><span class="n">autotune_key</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
|
||
|
||
|
||
<span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="autograd-function">
|
||
<h3>Autograd Function<a class="headerlink" href="#autograd-function" title="Permalink to this headline">¶</a></h3>
|
||
<p>Now we are ready to expose our auto-tuned kernel as a <cite>torch.autograd.Function</cite>.
|
||
To do so, we just need to define a <cite>forward</cite> function that takes two tensors as input and returns a tensor as output.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">_dot</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">Ka</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="n">Kb</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
|
||
<span class="k">assert</span> <span class="n">Ka</span> <span class="o">==</span> <span class="n">Kb</span><span class="p">,</span> <span class="s2">"incompatible dimensions"</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span> <span class="ow">and</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"inputs must be contiguous"</span>
|
||
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">MB</span><span class="p">)</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">NB</span><span class="p">),</span> <span class="p">)</span>
|
||
<span class="n">kernel</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">b</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">c</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> \
|
||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">Ka</span><span class="p">,</span> \
|
||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> \
|
||
<span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">c</span>
|
||
|
||
|
||
<span class="n">dot</span> <span class="o">=</span> <span class="n">_dot</span><span class="o">.</span><span class="n">apply</span>
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="unit-test">
|
||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||
<p>We can test our custom matrix multiplication operation against cuBLAS (i.e., <code class="code docutils literal notranslate"><span class="pre">torch.matmul</span></code>).
|
||
Note that we need to modify the <code class="code docutils literal notranslate"><span class="pre">atol</span></code> and <code class="code docutils literal notranslate"><span class="pre">rtol</span></code> parameters of <cite>torch.allclose</cite> to account for the fact that we are comparing FP16 tensors.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">768</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">((</span><span class="mi">768</span><span class="p">,</span> <span class="mi">896</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">c_0</span> <span class="o">=</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="n">c_1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">c_0</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">c_1</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_1</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||
...,
|
||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||
...,
|
||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
True
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="benchmark">
|
||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||
<div class="section" id="installing-the-cutlass-bindings">
|
||
<h3>Installing The CUTLASS Bindings<a class="headerlink" href="#installing-the-cutlass-bindings" title="Permalink to this headline">¶</a></h3>
|
||
<p>The cuBLAS library (used by <code class="code docutils literal notranslate"><span class="pre">torch.matmul</span></code>) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
|
||
For this reason, we will instead compare the performance of our kernel against <a class="reference external" href="https://github.com/NVIDIA/cutlass/">CUTLASS</a>, a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves.
|
||
To install CUTLASS, you need a recent version of cmake:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> /path/to/cutlass/
|
||
git clone https://github.com/NVIDIA/cutlass.git
|
||
<span class="nb">cd</span> cutlass
|
||
mkdir build
|
||
<span class="nb">cd</span> build
|
||
wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
|
||
tar xzvf *.tar.gz
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>You can then install CUTLASS as follows for V100:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED<span class="o">=</span><span class="m">70</span> -DCUTLASS_LIBRARY_KERNELS<span class="o">=</span>cutlass_tensorop_f16_s884gemm_f16_*_align8
|
||
make -j8 install
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>Or as follows for A100:</p>
|
||
<blockquote>
|
||
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED<span class="o">=</span><span class="m">80</span> -DCUTLASS_LIBRARY_KERNELS<span class="o">=</span>cutlass_tensorop_f16_s16816gemm_*align8
|
||
make -j8 install
|
||
</pre></div>
|
||
</div>
|
||
</div></blockquote>
|
||
<p>Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
|
||
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables <code class="code docutils literal notranslate"><span class="pre">CUTLASS_INCLUDE_DIR</span></code> and <code class="code docutils literal notranslate"><span class="pre">CUTLASS_LIBRARY_DIR</span></code> are set during the installation process.
|
||
To re-install Triton with the updated CUTLASS bindings, run the following command:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span> <span class="nv">CUTLASS_INCLUDE_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/include/
|
||
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/
|
||
pip uninstall -y triton
|
||
pip install -e <span class="s2">"git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>Which we can test as follows:</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
|
||
<span class="n">c_2</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">cutlass_matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">c_2</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||
...,
|
||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||
device='cuda:0', dtype=torch.float16)
|
||
True
|
||
</pre></div>
|
||
</div>
|
||
<p>Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.</p>
|
||
</div>
|
||
<div class="section" id="square-matrix-performance">
|
||
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline">¶</a></h3>
|
||
<p>We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||
<span class="n">y_name</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'cutlass'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
||
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'CUTLASS'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
|
||
<span class="p">)</span>
|
||
<span class="p">)</span>
|
||
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cutlass'</span><span class="p">:</span>
|
||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">cutlass_matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||
<span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||
|
||
|
||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<img alt="Line plot of matrix-multiplication performance in TFLOPS for Torch, Triton and CUTLASS" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
|
||
<p>As the plot shows, our kernel performs well: in this benchmark it is in fact faster than CUTLASS, and is therefore likely comparable to the best hand-tuned CUDA code an expert could write.</p>
|
||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> (1 minute 6.502 seconds)</p>
|
||
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||
</div>
|
||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">03-matrix-multiplication.ipynb</span></code></a></p>
|
||
</div>
|
||
</div>
|
||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
<footer>
|
||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||
<a href="../../programming-guide/chapter-1/introduction.html" class="btn btn-neutral float-right" title="Introduction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>
|
||
© Copyright 2020, Philippe Tillet.
|
||
|
||
</p>
|
||
</div>
|
||
|
||
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
|
||
</section>
|
||
|
||
</div>
|
||
|
||
|
||
<script type="text/javascript">
|
||
// Run once the DOM is ready (jQuery(fn) is the shorthand for $(document).ready(fn)):
// enable the Read the Docs theme's sidebar navigation behaviour provided by
// ../../_static/js/theme.js (loaded in <head>).
// NOTE(review): the `true` argument presumably turns on the sticky sidebar
// navigation option - confirm against the bundled theme.js.
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</body>
|
||
</html> |